Skip to content

Commit

Permalink
Formatted the code
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardogsilva committed May 27, 2024
1 parent abde69f commit 2ce2c1e
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 83 deletions.
8 changes: 2 additions & 6 deletions arpav_ppcv/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,9 @@ def update_station(
station_update: observations.StationUpdate,
) -> observations.Station:
"""Update a station."""
geom = from_shape(
shapely.io.from_geojson(station_update.geom.model_dump_json()))
geom = from_shape(shapely.io.from_geojson(station_update.geom.model_dump_json()))
other_data = station_update.model_dump(exclude={"geom"}, exclude_unset=True)
data = {
**other_data,
"geom": geom
}
data = {**other_data, "geom": geom}
for key, value in data.items():
setattr(db_station, key, value)
session.add(db_station)
Expand Down
22 changes: 16 additions & 6 deletions arpav_ppcv/webapp/admin/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,29 @@ def create_admin(settings: config.ArpavPpcvSettings) -> ArpavPpcvAdmin:
Middleware(SqlModelDbSessionMiddleware, engine=engine),
],
)
admin.add_view(coverage_views.ConfigurationParameterView(coverages.ConfigurationParameter))
admin.add_view(coverage_views.CoverageConfigurationView(coverages.CoverageConfiguration))
admin.add_view(
coverage_views.ConfigurationParameterView(coverages.ConfigurationParameter)
)
admin.add_view(
coverage_views.CoverageConfigurationView(coverages.CoverageConfiguration)
)
admin.add_view(observations_views.VariableView(observations.Variable))
admin.add_view(observations_views.StationView(observations.Station))
admin.add_view(
DropDown(
"Measurements",
icon="fa-solid fa-vials",
views=[
observations_views.MonthlyMeasurementView(observations.MonthlyMeasurement),
observations_views.SeasonalMeasurementView(observations.SeasonalMeasurement),
observations_views.YearlyMeasurementView(observations.YearlyMeasurement),
]
observations_views.MonthlyMeasurementView(
observations.MonthlyMeasurement
),
observations_views.SeasonalMeasurementView(
observations.SeasonalMeasurement
),
observations_views.YearlyMeasurementView(
observations.YearlyMeasurement
),
],
)
)
admin.add_view(
Expand Down
146 changes: 75 additions & 71 deletions arpav_ppcv/webapp/admin/views/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,21 @@ def can_view_details(self, request: Request) -> bool:

@staticmethod
def _serialize_instance(
instance: observations.MonthlyMeasurement) -> read_schemas.MonthlyMeasurementRead:
instance: observations.MonthlyMeasurement,
) -> read_schemas.MonthlyMeasurementRead:
return read_schemas.MonthlyMeasurementRead(
**instance.model_dump(),
station=instance.station.code,
variable=instance.variable.name,
)

async def find_all(
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
) -> Sequence[read_schemas.MonthlyMeasurementRead]:
list_measurements = functools.partial(
db.list_monthly_measurements,
Expand All @@ -77,7 +78,8 @@ async def find_all(
include_total=False,
)
db_measurements, _ = await anyio.to_thread.run_sync(
list_measurements, request.state.session)
list_measurements, request.state.session
)
return [self._serialize_instance(item) for item in db_measurements]


Expand Down Expand Up @@ -111,20 +113,21 @@ def can_view_details(self, request: Request) -> bool:

@staticmethod
def _serialize_instance(
instance: observations.SeasonalMeasurement) -> read_schemas.SeasonalMeasurementRead:
instance: observations.SeasonalMeasurement,
) -> read_schemas.SeasonalMeasurementRead:
return read_schemas.SeasonalMeasurementRead(
**instance.model_dump(),
station=instance.station.code,
variable=instance.variable.name,
)

async def find_all(
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
) -> Sequence[read_schemas.SeasonalMeasurementRead]:
list_measurements = functools.partial(
db.list_seasonal_measurements,
Expand All @@ -133,7 +136,8 @@ async def find_all(
include_total=False,
)
db_measurements, _ = await anyio.to_thread.run_sync(
list_measurements, request.state.session)
list_measurements, request.state.session
)
return [self._serialize_instance(item) for item in db_measurements]


Expand Down Expand Up @@ -166,20 +170,21 @@ def can_view_details(self, request: Request) -> bool:

@staticmethod
def _serialize_instance(
instance: observations.YearlyMeasurement) -> read_schemas.YearlyMeasurementRead:
instance: observations.YearlyMeasurement,
) -> read_schemas.YearlyMeasurementRead:
return read_schemas.YearlyMeasurementRead(
**instance.model_dump(),
station=instance.station.code,
variable=instance.variable.name,
)

async def find_all(
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
) -> Sequence[read_schemas.YearlyMeasurementRead]:
list_measurements = functools.partial(
db.list_yearly_measurements,
Expand All @@ -188,7 +193,8 @@ async def find_all(
include_total=False,
)
db_measurements, _ = await anyio.to_thread.run_sync(
list_measurements, request.state.session)
list_measurements, request.state.session
)
return [self._serialize_instance(item) for item in db_measurements]


Expand All @@ -215,7 +221,8 @@ def __init__(self, *args, **kwargs) -> None:

@staticmethod
def _serialize_instance(
instance: observations.Variable) -> read_schemas.VariableRead:
instance: observations.Variable,
) -> read_schemas.VariableRead:
return read_schemas.VariableRead(**instance.model_dump())

async def get_pk_value(self, request: Request, obj: Any) -> str:
Expand All @@ -227,7 +234,7 @@ async def get_pk_value(self, request: Request, obj: Any) -> str:
return str(result)

async def create(
self, request: Request, data: dict[str, Any]
self, request: Request, data: dict[str, Any]
) -> Optional[read_schemas.VariableRead]:
try:
data = await self._arrange_data(request, data)
Expand All @@ -243,7 +250,7 @@ async def create(
return self.handle_exception(e)

async def edit(
self, request: Request, pk: Any, data: dict[str, Any]
self, request: Request, pk: Any, data: dict[str, Any]
) -> Optional[read_schemas.VariableRead]:
try:
data = await self._arrange_data(request, data, True)
Expand All @@ -253,26 +260,26 @@ async def edit(
db.get_variable, request.state.session, pk
)
db_var = await anyio.to_thread.run_sync(
db.update_variable, request.state.session, db_var, var_update)
db.update_variable, request.state.session, db_var, var_update
)
return self._serialize_instance(db_var)
except Exception as e:
logger.exception("something went wrong")
self.handle_exception(e)

async def find_by_pk(
self, request: Request, pk: Any
) -> read_schemas.VariableRead:
async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.VariableRead:
db_var = await anyio.to_thread.run_sync(
db.get_variable, request.state.session, pk)
db.get_variable, request.state.session, pk
)
return self._serialize_instance(db_var)

async def find_all(
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
) -> Sequence[read_schemas.VariableRead]:
list_variables = functools.partial(
db.list_variables,
Expand All @@ -281,7 +288,8 @@ async def find_all(
include_total=False,
)
db_vars, _ = await anyio.to_thread.run_sync(
list_variables, request.state.session)
list_variables, request.state.session
)
return [self._serialize_instance(db_var) for db_var in db_vars]


Expand Down Expand Up @@ -312,16 +320,10 @@ def __init__(self, *args, **kwargs) -> None:
self.icon = "fa-solid fa-tower-observation"

@staticmethod
def _serialize_instance(
instance: observations.Station) -> read_schemas.StationRead:
def _serialize_instance(instance: observations.Station) -> read_schemas.StationRead:
geom = shapely.io.from_wkb(bytes(instance.geom.data))
return read_schemas.StationRead(
**instance.model_dump(
exclude={
"geom",
"type_"
}
),
**instance.model_dump(exclude={"geom", "type_"}),
type=instance.type_,
longitude=geom.x,
latitude=geom.y,
Expand Down Expand Up @@ -352,21 +354,26 @@ async def validate(self, request: Request, data: dict[str, Any]) -> None:
fields_to_exclude = [
f.name
for f in self.get_fields_list(request, request.state.action)
if isinstance(f, (starlette_admin.FileField, starlette_admin.RelationField))
if isinstance(
f, (starlette_admin.FileField, starlette_admin.RelationField)
)
] + ["latitude", "longitude"]
self.model.validate(
{k: v for k, v in data_to_validate.items() if k not in fields_to_exclude}
{
k: v
for k, v in data_to_validate.items()
if k not in fields_to_exclude
}
)

async def create(
self, request: Request, data: dict[str, Any]
self, request: Request, data: dict[str, Any]
) -> Optional[read_schemas.StationRead]:
try:
data = await self._arrange_data(request, data)
await self.validate(request, data)
geojson_geom = geojson_pydantic.Point(
type="Point",
coordinates=(data.pop("longitude"), data.pop("latitude"))
type="Point", coordinates=(data.pop("longitude"), data.pop("latitude"))
)
station_create = observations.StationCreate(
type_=data.pop("type"),
Expand All @@ -384,7 +391,7 @@ async def create(
return self.handle_exception(e)

async def edit(
self, request: Request, pk: Any, data: dict[str, Any]
self, request: Request, pk: Any, data: dict[str, Any]
) -> Optional[read_schemas.StationRead]:
try:
data = await self._arrange_data(request, data, True)
Expand All @@ -394,39 +401,35 @@ async def edit(
kwargs = {}
if all((lon, lat)):
kwargs["geom"] = geojson_pydantic.Point(
type="Point",
coordinates=(lon, lat)
type="Point", coordinates=(lon, lat)
)
if (type_ := data.pop("type", None)) is not None:
kwargs["type_"] = type_
station_update = observations.StationUpdate(
**data,
**kwargs
)
station_update = observations.StationUpdate(**data, **kwargs)
db_station = await anyio.to_thread.run_sync(
db.get_station, request.state.session, pk
)
db_station = await anyio.to_thread.run_sync(
db.update_station, request.state.session, db_station, station_update)
db.update_station, request.state.session, db_station, station_update
)
return self._serialize_instance(db_station)
except Exception as e:
logger.exception("something went wrong")
self.handle_exception(e)

async def find_by_pk(
self, request: Request, pk: Any
) -> read_schemas.StationRead:
async def find_by_pk(self, request: Request, pk: Any) -> read_schemas.StationRead:
db_station = await anyio.to_thread.run_sync(
db.get_station, request.state.session, pk)
db.get_station, request.state.session, pk
)
return self._serialize_instance(db_station)

async def find_all(
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
self,
request: Request,
skip: int = 0,
limit: int = 100,
where: Union[dict[str, Any], str, None] = None,
order_by: Optional[list[str]] = None,
) -> Sequence[read_schemas.StationRead]:
list_stations = functools.partial(
db.list_stations,
Expand All @@ -435,5 +438,6 @@ async def find_all(
include_total=False,
)
db_stations, _ = await anyio.to_thread.run_sync(
list_stations, request.state.session)
list_stations, request.state.session
)
return [self._serialize_instance(db_station) for db_station in db_stations]

0 comments on commit 2ce2c1e

Please sign in to comment.