diff --git a/cubedash/_stac.py b/cubedash/_stac.py index 6b46bb454..d9d1de845 100644 --- a/cubedash/_stac.py +++ b/cubedash/_stac.py @@ -15,6 +15,8 @@ from eodatasets3.properties import Eo3Dict from eodatasets3.utils import is_doc_eo3 from flask import abort, request +from pygeofilter.backends.cql2_json import to_cql2 +from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pystac import Catalog, Collection, Extent, ItemCollection, Link, STACObject from shapely.geometry import shape from shapely.geometry.base import BaseGeometry @@ -45,7 +47,7 @@ STAC_VERSION = "1.0.0" -ItemLike = pystac.Item | dict +ItemLike = Union[pystac.Item, dict] ############################ # Helpers @@ -398,7 +400,7 @@ def _geojson_arg(arg: dict) -> BaseGeometry: raise BadRequest("The 'intersects' argument must be valid GeoJSON geometry.") -def _bool_argument(s: str | bool): +def _bool_argument(s: Union[str, bool]): """ Parse an argument that should be a bool """ @@ -409,7 +411,7 @@ def _bool_argument(s: str | bool): return s.strip().lower() in ("1", "true", "on", "yes") -def _dict_arg(arg: str | dict): +def _dict_arg(arg: Union[str, dict]): """ Parse stac extension arguments as dicts """ @@ -418,7 +420,7 @@ def _dict_arg(arg: str | dict): return arg -def _field_arg(arg: str | list | dict): +def _field_arg(arg: Union[str, list, dict]): """ Parse field argument into a dict """ @@ -440,7 +442,7 @@ def _field_arg(arg: str | list | dict): return {"include": include, "exclude": exclude} -def _sort_arg(arg: str | list): +def _sort_arg(arg: Union[str, list]): """ Parse sortby argument into a list of dicts """ @@ -452,6 +454,8 @@ def _format(val: str) -> dict[str, str]: return {"field": val[1:], "direction": "asc"} return {"field": val.strip(), "direction": "asc"} + if isinstance(arg, str): + arg = arg.split(",") if len(arg): if isinstance(arg[0], str): return [_format(a) for a in arg] @@ -461,6 +465,19 @@ def _format(val: str) -> dict[str, str]: return arg +def _filter_arg(arg: Union[str, dict]): + # if dict, assume cql2-json and return as-is + # or do we need to use parse_cql2_json as well? + if isinstance(arg, dict): + return arg + # if json string, convert to dict + try: + return json.loads(arg) + except ValueError: + # else assume cql2-text and convert to json format + return json.loads(to_cql2(parse_cql2_text(arg))) + + # Search @@ -496,7 +513,8 @@ def _handle_search_request( sortby = request_args.get("sortby", default=None, type=_sort_arg) - filter_cql = request_args.get("filter", default=None, type=_dict_arg) + filter_cql = request_args.get("filter", default=None, type=_filter_arg) + # do we really need to return filter_lang? Or can we convert everything to cql-json if limit > PAGE_SIZE_LIMIT: abort( @@ -521,6 +539,7 @@ def next_page_url(next_offset): limit=limit, _o=next_offset, _full=full_information, + intersects=intersects, query=query, fields=fields, sortby=sortby, diff --git a/integration_tests/test_stac.py b/integration_tests/test_stac.py index 52317b5c8..bbd95c0b7 100644 --- a/integration_tests/test_stac.py +++ b/integration_tests/test_stac.py @@ -1305,6 +1305,7 @@ def test_stac_fields_extension(stac_client: FlaskClient): "properties", "stac_version", "stac_extensions", + "collection", } == keys properties = doc["features"][0]["properties"] assert {"datetime", "dea:dataset_maturity"} == set(properties.keys()) diff --git a/setup.py b/setup.py index 635409bae..34b7db56a 100755 --- a/setup.py +++ b/setup.py @@ -93,6 +93,7 @@ "sqlalchemy>=1.4", "structlog>=20.2.0", "pytz", + "pygeofilter", ], tests_require=tests_require, extras_require=extras_require,