diff --git a/.dockerignore b/.dockerignore index e187e45..ddfce7a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,5 +9,7 @@ coverage.xml *.log .git .envrc +*egg-info -venv \ No newline at end of file +venv +env diff --git a/CHANGES.md b/CHANGES.md index 8356053..c5887e7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,7 +2,11 @@ ## [Unreleased] +<<<<<<< HEAD - Enable filter extension for `GET /items` requests and add `Queryables` links in `/collections` and `/collections/{collection_id}` responses ([#89](https://github.com/stac-utils/stac-fastapi-pgstac/pull/89)) +======= +- Add collection search extension +>>>>>>> 9b2050a (add collection search extension) ## [3.0.0a4] - 2024-07-10 diff --git a/setup.py b/setup.py index 585737d..fc79deb 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=3.0.0b2", - "stac-fastapi.extensions~=3.0.0b2", - "stac-fastapi.types~=3.0.0b2", + "stac-fastapi.api~=3.0.0b3", + "stac-fastapi.extensions~=3.0.0b3", + "stac-fastapi.types~=3.0.0b3", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 924ea01..f798198 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -10,12 +10,14 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import ( + EmptyRequest, ItemCollectionUri, create_get_request_model, create_post_request_model, create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -47,12 +49,26 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } +collections_extensions_map = { + "collection_search": CollectionSearchExtension(), +} + if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): + _enabled_extensions = enabled_extensions.split(",") extensions = [ - extensions_map[extension_name] for extension_name in enabled_extensions.split(",") + extension + for key, extension in extensions_map.items() + if key in _enabled_extensions + ] + collection_extensions = [ + extension + for key, extension in collections_extensions_map.items() + if key in _enabled_extensions ] else: extensions = list(extensions_map.values()) + collection_extensions = list(collections_extensions_map.values()) + if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): items_get_request_model = create_request_model( @@ -64,17 +80,23 @@ else: items_get_request_model = ItemCollectionUri +if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): + collections_get_request_model = CollectionSearchExtension().GET +else: + collections_get_request_model = EmptyRequest + post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) api = StacApi( settings=settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, search_post_request_model=post_request_model, + collections_get_request_model=collections_get_request_model, ) app = api.app diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 648d42a..6ec8fad 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -39,17 +39,113 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - async def all_collections(self, request: Request, **kwargs) -> Collections: - """Read all collections from the database.""" + async def all_collections( # noqa: C901 + self, + request: Request, + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = None, + # Extensions + query: Optional[str] = None, + token: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /collections`. + + Returns: + Collections which match the search criteria, returns all + collections by default. + """ + query_params = str(request.query_params) + + # Kludgy fix because using factory does not allow alias for filter-lang + if filter_lang is None: + match = re.search(r"filter-lang=([a-z0-9-]+)", query_params, re.IGNORECASE) + if match: + filter_lang = match.group(1) + + # Parse request parameters + base_args = { + "bbox": bbox, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) + + # Do the request + try: + search_request = self.post_request_model(**clean) + except ValidationError as e: + raise HTTPException( + status_code=400, detail=f"Invalid parameters provided {e}" + ) from e + + return await self._collection_search_base(search_request, request=request) + + async def _collection_search_base( # noqa: C901 + self, + search_request: PgstacSearch, + request: Request, + ) -> Collections: + """Cross catalog search (POST). + + Called with `POST /search`. + + Args: + search_request: search request parameters. + + Returns: + All collections which match the search criteria. + """ + base_url = get_base_url(request) - async with request.app.state.get_connection(request, "r") as conn: - collections = await conn.fetchval( - """ - SELECT * FROM all_collections(); - """ - ) + settings: Settings = request.app.state.settings + + if search_request.datetime: + search_request.datetime = format_datetime_range(search_request.datetime) + + search_request.conf = search_request.conf or {} + search_request.conf["nohydrate"] = settings.use_api_hydrate + + search_request_json = search_request.model_dump_json( + exclude_none=True, by_alias=True + ) + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=search_request_json, + ) + collections_result: Collections = await conn.fetchval(q, *p) + except InvalidDatetimeFormatError as e: + raise InvalidQueryParameter( + f"Datetime parameter {search_request.datetime} is invalid." + ) from e + + # next: Optional[str] = collections_result["links"].pop("next") + # prev: Optional[str] = collections_result["links"].pop("prev") + linked_collections: List[Collection] = [] + collections = collections_result["collections"] if collections is not None and len(collections) > 0: for c in collections: coll = Collection(**c) @@ -71,6 +167,12 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: linked_collections.append(coll) + # paging_links = await PagingLinks( + # request=request, + # next=next, + # prev=prev, + # ).get_links() + links = [ { "rel": Relations.root.value, @@ -88,8 +190,10 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: "href": urljoin(base_url, "collections"), }, ] - collection_list = Collections(collections=linked_collections or [], links=links) - return collection_list + return Collections( + collections=linked_collections or [], + links=links, # + paging_links + ) async def get_collection( self, collection_id: str, request: Request, **kwargs @@ -383,7 +487,7 @@ async def post_search( return ItemCollection(**item_collection) - async def get_search( # noqa: C901 + async def get_search( self, request: Request, collections: Optional[List[str]] = None, @@ -418,49 +522,15 @@ async def get_search( # noqa: C901 "query": orjson.loads(unquote_plus(query)) if query else query, } - if filter: - if filter_lang == "cql2-text": - ast = parse_cql2_text(filter) - base_args["filter"] = orjson.loads(to_cql2(ast)) - base_args["filter-lang"] = "cql2-json" - - if datetime: - base_args["datetime"] = format_datetime_range(datetime) - - if intersects: - base_args["intersects"] = orjson.loads(unquote_plus(intersects)) - - if sortby: - # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form - sort_param = [] - for sort in sortby: - sortparts = re.match(r"^([+-]?)(.*)$", sort) - if sortparts: - sort_param.append( - { - "field": sortparts.group(2).strip(), - "direction": "desc" if sortparts.group(1) == "-" else "asc", - } - ) - base_args["sortby"] = sort_param - - if fields: - includes = set() - excludes = set() - for field in fields: - if field[0] == "-": - excludes.add(field[1:]) - elif field[0] == "+": - includes.add(field[1:]) - else: - includes.add(field) - base_args["fields"] = {"include": includes, "exclude": excludes} - - # Remove None values from dict - clean = {} - for k, v in base_args.items(): - if v is not None and v != []: - clean[k] = v + clean = clean_search_args( + base_args=base_args, + intersects=intersects, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) # Do the request try: @@ -471,3 +541,60 @@ async def get_search( # noqa: C901 ) from e return await self.post_search(search_request, request=request) + + +def clean_search_args( # noqa: C901 + base_args: dict[str, Any], + intersects: Optional[str] = None, + datetime: Optional[DateTimeType] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, +) -> dict[str, Any]: + """Clean up search arguments to match format expected by pgstac""" + if filter: + if filter_lang == "cql2-text": + ast = parse_cql2_text(filter) + base_args["filter"] = orjson.loads(to_cql2(ast)) + base_args["filter-lang"] = "cql2-json" + + if datetime: + base_args["datetime"] = format_datetime_range(datetime) + + if intersects: + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) + + if sortby: + # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form + sort_param = [] + for sort in sortby: + sortparts = re.match(r"^([+-]?)(.*)$", sort) + if sortparts: + sort_param.append( + { + "field": sortparts.group(2).strip(), + "direction": "desc" if sortparts.group(1) == "-" else "asc", + } + ) + base_args["sortby"] = sort_param + + if fields: + includes = set() + excludes = set() + for field in fields: + if field[0] == "-": + excludes.add(field[1:]) + elif field[0] == "+": + includes.add(field[1:]) + else: + includes.add(field) + base_args["fields"] = {"include": includes, "exclude": excludes} + + # Remove None values from dict + clean = {} + for k, v in base_args.items(): + if v is not None and v != []: + clean[k] = v + + return clean diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 2077c35..fddc9a3 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -502,6 +502,33 @@ async def test_collection_queryables(load_test_data, app_client, load_test_colle assert "id" in q["properties"] +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 + + @pytest.mark.asyncio async def test_item_collection_filter_bbox( load_test_data, app_client, load_test_collection diff --git a/tests/conftest.py b/tests/conftest.py index e5955b7..8f9480c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -133,6 +134,7 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] + collection_extensions = [CollectionSearchExtension()] items_get_request_model = create_request_model( model_name="ItemCollectionUri", @@ -149,11 +151,12 @@ def api_client(request, database): ) api = StacApi( settings=api_settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, + collections_get_request_model=CollectionSearchExtension().GET, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), ) diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 3a2183b..634747b 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -276,3 +276,30 @@ async def test_get_collections_queryables_links(app_client, load_test_collection f"/collections/{collection_id}", ) assert "Queryables" in [link.get("title") for link in resp.json()["links"]] + + +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2