diff --git a/arpav_ppcv/database.py b/arpav_ppcv/database.py index e1a04a8f..dea9a3d6 100644 --- a/arpav_ppcv/database.py +++ b/arpav_ppcv/database.py @@ -231,10 +231,14 @@ def list_stations( offset: int = 0, include_total: bool = False, polygon_intersection_filter: shapely.Polygon = None, + variable_id_filter: Optional[uuid.UUID] = None, + variable_aggregation_type: Optional[ + base.ObservationAggregationType + ] = base.ObservationAggregationType.SEASONAL, ) -> tuple[Sequence[observations.Station], Optional[int]]: """List existing stations. - The ``polygon_intersetion_filter`` parameter is expected to be a polygon + The ``polygon_intersection_filter`` parameter is expected to be a polygon geometry in the EPSG:4326 CRS. """ statement = sqlmodel.select(observations.Station).order_by( @@ -249,6 +253,29 @@ def list_stations( ), ) ) + if all((variable_id_filter, variable_aggregation_type)): + if variable_aggregation_type == base.ObservationAggregationType.MONTHLY: + instance_class = observations.MonthlyMeasurement + elif variable_aggregation_type == base.ObservationAggregationType.SEASONAL: + instance_class = observations.SeasonalMeasurement + elif variable_aggregation_type == base.ObservationAggregationType.YEARLY: + instance_class = observations.YearlyMeasurement + else: + raise RuntimeError( + f"variable filtering for {variable_aggregation_type} is not supported" + ) + statement = ( + statement.join(instance_class) + .join(observations.Variable) + .where(observations.Variable.id == variable_id_filter) + .distinct() + ) + + else: + logger.warning( + "Did not perform variable filter as not all related parameters have been " + "provided" + ) items = session.exec(statement.offset(offset).limit(limit)).all() num_items = _get_total_num_records(session, statement) if include_total else None return items, num_items diff --git a/arpav_ppcv/schemas/base.py b/arpav_ppcv/schemas/base.py index 563ae9c1..74dc8254 100644 --- a/arpav_ppcv/schemas/base.py +++ b/arpav_ppcv/schemas/base.py @@ -19,7 +19,7 @@ class ObservationDataSmoothingStrategy(enum.Enum): RELATED_TIME_SERIES_PATTERN = "**RELATED**" -class ObservationAggregationType(enum.Enum): +class ObservationAggregationType(str, enum.Enum): MONTHLY = "MONTHLY" SEASONAL = "SEASONAL" YEARLY = "YEARLY" diff --git a/arpav_ppcv/schemas/observations.py b/arpav_ppcv/schemas/observations.py index 7aed6f71..202589bf 100644 --- a/arpav_ppcv/schemas/observations.py +++ b/arpav_ppcv/schemas/observations.py @@ -66,6 +66,16 @@ class Station(StationBase, table=True): "passive_deletes": True, }, ) + monthly_variables: list["Variable"] = sqlmodel.Relationship( + sa_relationship_kwargs={ + "primaryjoin": ( + "and_(Station.id == MonthlyMeasurement.station_id, " + "Variable.id == MonthlyMeasurement.variable_id)" + ), + "secondary": "monthlymeasurement", + "viewonly": True, + } + ) seasonal_measurements: list["SeasonalMeasurement"] = sqlmodel.Relationship( back_populates="station", sa_relationship_kwargs={ @@ -78,6 +88,16 @@ class Station(StationBase, table=True): "passive_deletes": True, }, ) + seasonal_variables: list["Variable"] = sqlmodel.Relationship( + sa_relationship_kwargs={ + "primaryjoin": ( + "and_(Station.id == SeasonalMeasurement.station_id, " + "Variable.id == SeasonalMeasurement.variable_id)" + ), + "secondary": "seasonalmeasurement", + "viewonly": True, + } + ) yearly_measurements: list["YearlyMeasurement"] = sqlmodel.Relationship( back_populates="station", sa_relationship_kwargs={ @@ -90,6 +110,16 @@ class Station(StationBase, table=True): "passive_deletes": True, }, ) + yearly_variables: list["Variable"] = sqlmodel.Relationship( + sa_relationship_kwargs={ + "primaryjoin": ( + "and_(Station.id == YearlyMeasurement.station_id, " + "Variable.id == YearlyMeasurement.variable_id)" + ), + "secondary": "yearlymeasurement", + "viewonly": True, + } + ) class StationCreate(sqlmodel.SQLModel): diff --git a/arpav_ppcv/webapp/api_v2/routers/observations.py b/arpav_ppcv/webapp/api_v2/routers/observations.py index b527cdb3..92f3073b 100644 --- a/arpav_ppcv/webapp/api_v2/routers/observations.py +++ b/arpav_ppcv/webapp/api_v2/routers/observations.py @@ -8,9 +8,11 @@ Depends, Header, Request, + Query, ) from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse +from fastapi.exceptions import HTTPException from sqlmodel import Session from .... import database @@ -45,14 +47,32 @@ def list_stations( request: Request, db_session: Annotated[Session, Depends(dependencies.get_db_session)], list_params: Annotated[dependencies.CommonListFilterParameters, Depends()], + variable_name: str | None = None, + temporal_aggregation: Annotated[ + base.ObservationAggregationType, Query() + ] = base.ObservationAggregationType.SEASONAL, accept: Annotated[str | None, Header()] = None, ): """List known stations.""" + filter_kwargs = {} + if variable_name is not None: + if ( + db_var := database.get_variable_by_name(db_session, variable_name) + ) is not None: + filter_kwargs.update( + { + "variable_id_filter": db_var.id, + "variable_aggregation_type": temporal_aggregation, + } + ) + else: + raise HTTPException(status_code=400, detail="Invalid variable name") stations, filtered_total = database.list_stations( db_session, limit=list_params.limit, offset=list_params.offset, include_total=True, + **filter_kwargs, ) _, unfiltered_total = database.list_stations( db_session, limit=1, offset=0, include_total=True diff --git a/arpav_ppcv/webapp/api_v2/schemas/geojson/observations.py b/arpav_ppcv/webapp/api_v2/schemas/geojson/observations.py index 6675854b..d5704e4f 100644 --- a/arpav_ppcv/webapp/api_v2/schemas/geojson/observations.py +++ b/arpav_ppcv/webapp/api_v2/schemas/geojson/observations.py @@ -6,6 +6,7 @@ observations, fields, ) +from ..observations import VariableReadEmbeddedInStationRead from .base import ArpavFeatureCollection @@ -27,12 +28,26 @@ def from_db_instance( return cls( id=instance.id, geometry=instance.geom, - properties=instance.model_dump( - exclude={ - "id", - "geom", - } - ), + properties={ + **instance.model_dump( + exclude={ + "id", + "geom", + } + ), + "monthly_variables": [ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.monthly_variables + ], + "seasonal_variables": [ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.seasonal_variables + ], + "yearly_variables": [ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.yearly_variables + ], + }, links=[str(url)], ) diff --git a/arpav_ppcv/webapp/api_v2/schemas/observations.py b/arpav_ppcv/webapp/api_v2/schemas/observations.py index 792cc50f..c39ebdfd 100644 --- a/arpav_ppcv/webapp/api_v2/schemas/observations.py +++ b/arpav_ppcv/webapp/api_v2/schemas/observations.py @@ -1,4 +1,5 @@ import logging +import uuid import pydantic from fastapi import Request @@ -10,8 +11,16 @@ logger = logging.getLogger(__name__) +class VariableReadEmbeddedInStationRead(pydantic.BaseModel): + id: uuid.UUID + name: str + + class StationReadListItem(observations.StationBase): url: pydantic.AnyHttpUrl + monthly_variables: list[VariableReadEmbeddedInStationRead] + seasonal_variables: list[VariableReadEmbeddedInStationRead] + yearly_variables: list[VariableReadEmbeddedInStationRead] @classmethod def from_db_instance( @@ -22,6 +31,18 @@ def from_db_instance( url = request.url_for("get_station", **{"station_id": instance.id}) return cls( **instance.model_dump(), + monthly_variables=[ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.monthly_variables + ], + seasonal_variables=[ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.seasonal_variables + ], + yearly_variables=[ + VariableReadEmbeddedInStationRead(**v.model_dump()) + for v in instance.yearly_variables + ], url=str(url), ) diff --git a/tests/notebooks/generic.ipynb b/tests/notebooks/generic.ipynb index 1653145c..bd99633b 100644 --- a/tests/notebooks/generic.ipynb +++ b/tests/notebooks/generic.ipynb @@ -25,6 +25,7 @@ "from arpav_ppcv.schemas.base import (\n", " CoverageDataSmoothingStrategy,\n", " ObservationDataSmoothingStrategy,\n", + " ObservationAggregationType,\n", " Season,\n", ")\n", "\n", @@ -35,31 +36,37 @@ "\n", "settings = get_settings()\n", "session = sqlmodel.Session(db.get_engine(settings))\n", - "http_client = httpx.Client()\n", - "\n", - "coverage_identifier = \"uncertainty_bounds_test-rcp26-DJF\"\n", - "coverage_configuration = db.get_coverage_configuration_by_coverage_identifier(\n", - " session, coverage_identifier)\n" + "http_client = httpx.Client()" ], "outputs": [] }, { "cell_type": "code", "execution_count": 2, - "id": "f564bc8c-cf2a-410d-ba89-d3686c9aadb7", + "id": "4528e9d8-18b5-4579-b80e-423ce6bd5620", + "metadata": {}, + "source": [ + "station = db.get_station_by_code(session, \"93\")" + ], + "outputs": [] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b09d371f-2f1f-44b4-966a-e33808e0d9fb", "metadata": {}, "source": [ - "coverage_configuration.uncertainty_lower_bounds_coverage_configuration" + "station.seasonal_variables" ], "outputs": [] }, { "cell_type": "code", - "execution_count": 3, - "id": "5ffd5df7-48b6-4822-99f5-dd60e1328d31", + "execution_count": 7, + "id": "4eb9b279-9f7a-4031-a9c3-ccc15d32ec09", "metadata": {}, "source": [ - "coverage_configuration.uncertainty_upper_bounds_coverage_configuration" + "db.collect_station_variables(session, station, ObservationAggregationType.YEARLY)" ], "outputs": [] } @@ -80,7 +87,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.14" } }, "nbformat": 4,