From c25229ed50571d56039388c523ed974f4babd148 Mon Sep 17 00:00:00 2001 From: Stijn Caerts Date: Thu, 5 Sep 2024 11:19:46 +0200 Subject: [PATCH] refactor aggregate in database_logic (#294) **Related Issue(s):** N/A **Description:** Refactor `aggregate()` in database logic to allow extending the supported set of aggregations. The mapping of aggregation name to Elasticsearch/OpenSearch functionality was in the `aggregate()` function, which made it difficult to alter the set of supported aggregations. I moved the mapping to a property of the database logic, so it can be modified when the database logic is instantiated. **PR Checklist:** - [x] Code is formatted and linted (run `pre-commit run --all-files`) - [x] Tests pass (run `make test`) - [x] Documentation has been updated to reflect changes, if applicable - [x] Changes are added to the changelog --------- Co-authored-by: Jonathan Healy --- CHANGELOG.md | 4 + .../core/extensions/aggregation.py | 35 ++++ .../elasticsearch/database_logic.py | 177 +++++++++-------- .../stac_fastapi/opensearch/database_logic.py | 179 +++++++++--------- 4 files changed, 225 insertions(+), 170 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 299d8ba1..75d12221 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ### Added +- Added `datetime_frequency_interval` parameter for `datetime_frequency` aggregation. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294) + ### Changed + +- Refactored aggregation in database logic. [#294](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/294) - Fixed the `self` link for the `/collections/{collection_id}/aggregations` endpoint. [#295](https://github.com/stac-utils/stac-fastapi-elasticsearch-opensearch/pull/295) ## [v3.1.0] - 2024-09-02 diff --git a/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py b/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py index 27f6b458..2cf880c9 100644 --- a/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py +++ b/stac_fastapi/core/stac_fastapi/core/extensions/aggregation.py @@ -50,6 +50,7 @@ class EsAggregationExtensionGetRequest( centroid_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None) geometry_geohash_grid_frequency_precision: Optional[int] = attr.ib(default=None) geometry_geotile_grid_frequency_precision: Optional[int] = attr.ib(default=None) + datetime_frequency_interval: Optional[str] = attr.ib(default=None) class EsAggregationExtensionPostRequest( @@ -62,6 +63,7 @@ class EsAggregationExtensionPostRequest( centroid_geotile_grid_frequency_precision: Optional[int] = None geometry_geohash_grid_frequency_precision: Optional[int] = None geometry_geotile_grid_frequency_precision: Optional[int] = None + datetime_frequency_interval: Optional[str] = None @attr.s @@ -124,6 +126,8 @@ class EsAsyncAggregationClient(AsyncBaseAggregationClient): MAX_GEOHASH_PRECISION = 12 MAX_GEOHEX_PRECISION = 15 MAX_GEOTILE_PRECISION = 29 + SUPPORTED_DATETIME_INTERVAL = {"day", "month", "year"} + DEFAULT_DATETIME_INTERVAL = "month" async def get_aggregations(self, collection_id: Optional[str] = None, **kwargs): """Get the available aggregations for a catalog or collection defined in the STAC JSON. If no aggregations, default aggregations are used.""" @@ -182,6 +186,30 @@ def extract_precision( else: return min_value + def extract_date_histogram_interval(self, value: Optional[str]) -> str: + """ + Ensure that the interval for the date histogram is valid. If no value is provided, the default will be returned. + + Args: + value: value entered by the user + + Returns: + string containing the date histogram interval to use. + + Raises: + HTTPException: if the supplied value is not in the supported intervals + """ + if value is not None: + if value not in self.SUPPORTED_DATETIME_INTERVAL: + raise HTTPException( + status_code=400, + detail=f"Invalid datetime interval. Must be one of {self.SUPPORTED_DATETIME_INTERVAL}", + ) + else: + return value + else: + return self.DEFAULT_DATETIME_INTERVAL + @staticmethod def _return_date( interval: Optional[Union[DateTimeType, str]] @@ -319,6 +347,7 @@ async def aggregate( centroid_geotile_grid_frequency_precision: Optional[int] = None, geometry_geohash_grid_frequency_precision: Optional[int] = None, geometry_geotile_grid_frequency_precision: Optional[int] = None, + datetime_frequency_interval: Optional[str] = None, **kwargs, ) -> Union[Dict, Exception]: """Get aggregations from the database.""" @@ -339,6 +368,7 @@ async def aggregate( "centroid_geotile_grid_frequency_precision": centroid_geotile_grid_frequency_precision, "geometry_geohash_grid_frequency_precision": geometry_geohash_grid_frequency_precision, "geometry_geotile_grid_frequency_precision": geometry_geotile_grid_frequency_precision, + "datetime_frequency_interval": datetime_frequency_interval, } if collection_id: @@ -475,6 +505,10 @@ async def aggregate( self.MAX_GEOTILE_PRECISION, ) + datetime_frequency_interval = self.extract_date_histogram_interval( + aggregate_request.datetime_frequency_interval, + ) + try: db_response = await self.database.aggregate( collections, @@ -485,6 +519,7 @@ async def aggregate( centroid_geotile_grid_precision, geometry_geohash_grid_precision, geometry_geotile_grid_precision, + datetime_frequency_interval, ) except Exception as error: if not isinstance(error, IndexError): diff --git a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py index 7aa887b5..da6d6880 100644 --- a/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py +++ b/stac_fastapi/elasticsearch/stac_fastapi/elasticsearch/database_logic.py @@ -4,6 +4,7 @@ import logging import os from base64 import urlsafe_b64decode, urlsafe_b64encode +from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union import attr @@ -316,6 +317,77 @@ class DatabaseLogic: extensions: List[str] = attr.ib(default=attr.Factory(list)) + aggregation_mapping: Dict[str, Dict[str, Any]] = { + "total_count": {"value_count": {"field": "id"}}, + "collection_frequency": {"terms": {"field": "collection", "size": 100}}, + "platform_frequency": {"terms": {"field": "properties.platform", "size": 100}}, + "cloud_cover_frequency": { + "range": { + "field": "properties.eo:cloud_cover", + "ranges": [ + {"to": 5}, + {"from": 5, "to": 15}, + {"from": 15, "to": 40}, + {"from": 40}, + ], + } + }, + "datetime_frequency": { + "date_histogram": { + "field": "properties.datetime", + "calendar_interval": "month", + } + }, + "datetime_min": {"min": {"field": "properties.datetime"}}, + "datetime_max": {"max": {"field": "properties.datetime"}}, + "grid_code_frequency": { + "terms": { + "field": "properties.grid:code", + "missing": "none", + "size": 10000, + } + }, + "sun_elevation_frequency": { + "histogram": {"field": "properties.view:sun_elevation", "interval": 5} + }, + "sun_azimuth_frequency": { + "histogram": {"field": "properties.view:sun_azimuth", "interval": 5} + }, + "off_nadir_frequency": { + "histogram": {"field": "properties.view:off_nadir", "interval": 5} + }, + "centroid_geohash_grid_frequency": { + "geohash_grid": { + "field": "properties.proj:centroid", + "precision": 1, + } + }, + "centroid_geohex_grid_frequency": { + "geohex_grid": { + "field": "properties.proj:centroid", + "precision": 0, + } + }, + "centroid_geotile_grid_frequency": { + "geotile_grid": { + "field": "properties.proj:centroid", + "precision": 0, + } + }, + "geometry_geohash_grid_frequency": { + "geohash_grid": { + "field": "geometry", + "precision": 1, + } + }, + "geometry_geotile_grid_frequency": { + "geotile_grid": { + "field": "geometry", + "precision": 0, + } + }, + } + """CORE LOGIC""" async def get_all_collections( @@ -657,52 +729,10 @@ async def aggregate( centroid_geotile_grid_precision: int, geometry_geohash_grid_precision: int, geometry_geotile_grid_precision: int, + datetime_frequency_interval: str, ignore_unavailable: Optional[bool] = True, ): """Return aggregations of STAC Items.""" - agg_2_es = { - "total_count": {"value_count": {"field": "id"}}, - "collection_frequency": {"terms": {"field": "collection", "size": 100}}, - "platform_frequency": { - "terms": {"field": "properties.platform", "size": 100} - }, - "cloud_cover_frequency": { - "range": { - "field": "properties.eo:cloud_cover", - "ranges": [ - {"to": 5}, - {"from": 5, "to": 15}, - {"from": 15, "to": 40}, - {"from": 40}, - ], - } - }, - "datetime_frequency": { - "date_histogram": { - "field": "properties.datetime", - "calendar_interval": "month", - } - }, - "datetime_min": {"min": {"field": "properties.datetime"}}, - "datetime_max": {"max": {"field": "properties.datetime"}}, - "grid_code_frequency": { - "terms": { - "field": "properties.grid:code", - "missing": "none", - "size": 10000, - } - }, - "sun_elevation_frequency": { - "histogram": {"field": "properties.view:sun_elevation", "interval": 5} - }, - "sun_azimuth_frequency": { - "histogram": {"field": "properties.view:sun_azimuth", "interval": 5} - }, - "off_nadir_frequency": { - "histogram": {"field": "properties.view:off_nadir", "interval": 5} - }, - } - search_body: Dict[str, Any] = {} query = search.query.to_dict() if search.query else None if query: @@ -710,51 +740,30 @@ async def aggregate( logger.debug("Aggregations: %s", aggregations) - # include all aggregations specified - # this will ignore aggregations with the wrong names - search_body["aggregations"] = { - k: v for k, v in agg_2_es.items() if k in aggregations - } - - if "centroid_geohash_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geohash_grid_frequency"] = { - "geohash_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geohash_grid_precision, - } - } - - if "centroid_geohex_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geohex_grid_frequency"] = { - "geohex_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geohex_grid_precision, - } + def _fill_aggregation_parameters(name: str, agg: dict) -> dict: + [key] = agg.keys() + agg_precision = { + "centroid_geohash_grid_frequency": centroid_geohash_grid_precision, + "centroid_geohex_grid_frequency": centroid_geohex_grid_precision, + "centroid_geotile_grid_frequency": centroid_geotile_grid_precision, + "geometry_geohash_grid_frequency": geometry_geohash_grid_precision, + "geometry_geotile_grid_frequency": geometry_geotile_grid_precision, } + if name in agg_precision: + agg[key]["precision"] = agg_precision[name] - if "centroid_geotile_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geotile_grid_frequency"] = { - "geotile_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geotile_grid_precision, - } - } + if key == "date_histogram": + agg[key]["calendar_interval"] = datetime_frequency_interval - if "geometry_geohash_grid_frequency" in aggregations: - search_body["aggregations"]["geometry_geohash_grid_frequency"] = { - "geohash_grid": { - "field": "geometry", - "precision": geometry_geohash_grid_precision, - } - } + return agg - if "geometry_geotile_grid_frequency" in aggregations: - search_body["aggregations"]["geometry_geotile_grid_frequency"] = { - "geotile_grid": { - "field": "geometry", - "precision": geometry_geotile_grid_precision, - } - } + # include all aggregations specified + # this will ignore aggregations with the wrong names + search_body["aggregations"] = { + k: _fill_aggregation_parameters(k, deepcopy(v)) + for k, v in self.aggregation_mapping.items() + if k in aggregations + } index_param = indices(collection_ids) search_task = asyncio.create_task( diff --git a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py index 014ea57b..778cfe03 100644 --- a/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py +++ b/stac_fastapi/opensearch/stac_fastapi/opensearch/database_logic.py @@ -4,6 +4,7 @@ import logging import os from base64 import urlsafe_b64decode, urlsafe_b64encode +from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple, Type, Union import attr @@ -337,6 +338,77 @@ class DatabaseLogic: extensions: List[str] = attr.ib(default=attr.Factory(list)) + aggregation_mapping: Dict[str, Dict[str, Any]] = { + "total_count": {"value_count": {"field": "id"}}, + "collection_frequency": {"terms": {"field": "collection", "size": 100}}, + "platform_frequency": {"terms": {"field": "properties.platform", "size": 100}}, + "cloud_cover_frequency": { + "range": { + "field": "properties.eo:cloud_cover", + "ranges": [ + {"to": 5}, + {"from": 5, "to": 15}, + {"from": 15, "to": 40}, + {"from": 40}, + ], + } + }, + "datetime_frequency": { + "date_histogram": { + "field": "properties.datetime", + "calendar_interval": "month", + } + }, + "datetime_min": {"min": {"field": "properties.datetime"}}, + "datetime_max": {"max": {"field": "properties.datetime"}}, + "grid_code_frequency": { + "terms": { + "field": "properties.grid:code", + "missing": "none", + "size": 10000, + } + }, + "sun_elevation_frequency": { + "histogram": {"field": "properties.view:sun_elevation", "interval": 5} + }, + "sun_azimuth_frequency": { + "histogram": {"field": "properties.view:sun_azimuth", "interval": 5} + }, + "off_nadir_frequency": { + "histogram": {"field": "properties.view:off_nadir", "interval": 5} + }, + "centroid_geohash_grid_frequency": { + "geohash_grid": { + "field": "properties.proj:centroid", + "precision": 1, + } + }, + "centroid_geohex_grid_frequency": { + "geohex_grid": { + "field": "properties.proj:centroid", + "precision": 0, + } + }, + "centroid_geotile_grid_frequency": { + "geotile_grid": { + "field": "properties.proj:centroid", + "precision": 0, + } + }, + "geometry_geohash_grid_frequency": { + "geohash_grid": { + "field": "geometry", + "precision": 1, + } + }, + "geometry_geotile_grid_frequency": { + "geotile_grid": { + "field": "geometry", + "precision": 0, + } + }, + } + """CORE LOGIC""" async def get_all_collections( @@ -689,104 +761,39 @@ async def aggregate( centroid_geotile_grid_precision: int, geometry_geohash_grid_precision: int, geometry_geotile_grid_precision: int, + datetime_frequency_interval: str, ignore_unavailable: Optional[bool] = True, ): """Return aggregations of STAC Items.""" - agg_2_es = { - "total_count": {"value_count": {"field": "id"}}, - "collection_frequency": {"terms": {"field": "collection", "size": 100}}, - "platform_frequency": { - "terms": {"field": "properties.platform", "size": 100} - }, - "cloud_cover_frequency": { - "range": { - "field": "properties.eo:cloud_cover", - "ranges": [ - {"to": 5}, - {"from": 5, "to": 15}, - {"from": 15, "to": 40}, - {"from": 40}, - ], - } - }, - "datetime_frequency": { - "date_histogram": { - "field": "properties.datetime", - "calendar_interval": "month", - } - }, - "datetime_min": {"min": {"field": "properties.datetime"}}, - "datetime_max": {"max": {"field": "properties.datetime"}}, - "grid_code_frequency": { - "terms": { - "field": "properties.grid:code", - "missing": "none", - "size": 10000, - } - }, - "sun_elevation_frequency": { - "histogram": {"field": "properties.view:sun_elevation", "interval": 5} - }, - "sun_azimuth_frequency": { - "histogram": {"field": "properties.view:sun_azimuth", "interval": 5} - }, - "off_nadir_frequency": { - "histogram": {"field": "properties.view:off_nadir", "interval": 5} - }, - } - search_body: Dict[str, Any] = {} query = search.query.to_dict() if search.query else None if query: search_body["query"] = query - # include all aggregations specified - # this will ignore aggregations with the wrong names - search_body["aggregations"] = { - k: v for k, v in agg_2_es.items() if k in aggregations - } - - # centroid - if "centroid_geohash_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geohash_grid_frequency"] = { - "geohash_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geohash_grid_precision, - } - } - - if "centroid_geohex_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geohex_grid_frequency"] = { - "geohex_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geohex_grid_precision, - } + def _fill_aggregation_parameters(name: str, agg: dict) -> dict: + [key] = agg.keys() + agg_precision = { + "centroid_geohash_grid_frequency": centroid_geohash_grid_precision, + "centroid_geohex_grid_frequency": centroid_geohex_grid_precision, + "centroid_geotile_grid_frequency": centroid_geotile_grid_precision, + "geometry_geohash_grid_frequency": geometry_geohash_grid_precision, + "geometry_geotile_grid_frequency": geometry_geotile_grid_precision, } + if name in agg_precision: + agg[key]["precision"] = agg_precision[name] - if "centroid_geotile_grid_frequency" in aggregations: - search_body["aggregations"]["centroid_geotile_grid_frequency"] = { - "geotile_grid": { - "field": "properties.proj:centroid", - "precision": centroid_geotile_grid_precision, - } - } + if key == "date_histogram": + agg[key]["calendar_interval"] = datetime_frequency_interval - # geometry - if "geometry_geohash_grid_frequency" in aggregations: - search_body["aggregations"]["geometry_geohash_grid_frequency"] = { - "geohash_grid": { - "field": "geometry", - "precision": geometry_geohash_grid_precision, - } - } + return agg - if "geometry_geotile_grid_frequency" in aggregations: - search_body["aggregations"]["geometry_geotile_grid_frequency"] = { - "geotile_grid": { - "field": "geometry", - "precision": geometry_geotile_grid_precision, - } - } + # include all aggregations specified + # this will ignore aggregations with the wrong names + search_body["aggregations"] = { + k: _fill_aggregation_parameters(k, deepcopy(v)) + for k, v in self.aggregation_mapping.items() + if k in aggregations + } index_param = indices(collection_ids) search_task = asyncio.create_task(