diff --git a/.dockerignore b/.dockerignore index e187e45a..ddfce7a0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,5 +9,7 @@ coverage.xml *.log .git .envrc +*egg-info -venv \ No newline at end of file +venv +env diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 2618036c..d4a03663 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -47,7 +47,7 @@ jobs: runs-on: ubuntu-latest services: pgstac: - image: ghcr.io/stac-utils/pgstac:v0.7.10 + image: ghcr.io/stac-utils/pgstac:v0.8.6 env: POSTGRES_USER: username POSTGRES_PASSWORD: password diff --git a/CHANGES.md b/CHANGES.md index c9c124e5..9e43b5e2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,8 @@ ## [Unreleased] +- Add collection search extension + ## [3.0.0] - 2024-08-02 - Enable filter extension for `GET /items` requests and add `Queryables` links in `/collections` and `/collections/{collection_id}` responses ([#89](https://github.com/stac-utils/stac-fastapi-pgstac/pull/89)) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 924ea01f..10b05377 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -10,12 +10,14 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import ( + EmptyRequest, ItemCollectionUri, create_get_request_model, create_post_request_model, create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -47,12 +49,26 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } +collections_extensions_map = { + "collection_search": CollectionSearchExtension(), +} + if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): + _enabled_extensions = enabled_extensions.split(",") extensions = [ - extensions_map[extension_name] for extension_name in enabled_extensions.split(",") + extension + for key, extension in extensions_map.items() + if key in _enabled_extensions + ] + collection_extensions = [ + extension + for key, extension in collections_extensions_map.items() + if key in _enabled_extensions ] else: extensions = list(extensions_map.values()) + collection_extensions = list(collections_extensions_map.values()) + if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): items_get_request_model = create_request_model( @@ -64,12 +80,19 @@ else: items_get_request_model = ItemCollectionUri -post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) -get_request_model = create_get_request_model(extensions) +if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): + collections_get_request_model = CollectionSearchExtension().GET +else: + collections_get_request_model = EmptyRequest + +post_request_model = create_post_request_model( + extensions + collection_extensions, base_model=PgstacSearch +) +get_request_model = create_get_request_model(extensions + collection_extensions) api = StacApi( settings=settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 648d42a1..4f6f6f11 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -14,12 +14,11 @@ from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate from stac_fastapi.api.models import JSONResponse -from stac_fastapi.types.core import AsyncBaseCoreClient +from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection -from stac_pydantic.links import Relations from stac_pydantic.shared import BBox, MimeTypes from stac_fastapi.pgstac.config import Settings @@ -39,17 +38,100 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - async def all_collections(self, request: Request, **kwargs) -> Collections: - """Read all collections from the database.""" + async def all_collections( # noqa: C901 + self, + request: Request, + # Extensions + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = None, + query: Optional[str] = None, + token: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /collections`. + + Returns: + Collections which match the search criteria, returns all + collections by default. + """ + + # Parse request parameters + base_args = { + "bbox": bbox, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) + + # Do the request + try: + search_request = self.post_request_model(**clean) + except ValidationError as e: + raise HTTPException( + status_code=400, detail=f"Invalid parameters provided {e}" + ) from e + + return await self._collection_search_base(search_request, request=request) + + async def _collection_search_base( # noqa: C901 + self, + search_request: PgstacSearch, + request: Request, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /search`. + + Args: + search_request: search request parameters. + + Returns: + All collections which match the search criteria. + """ base_url = get_base_url(request) + search_request_json = search_request.model_dump_json( + exclude_none=True, by_alias=True + ) + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=search_request_json, + ) + collections_result: Collections = await conn.fetchval(q, *p) + except InvalidDatetimeFormatError as e: + raise InvalidQueryParameter( + f"Datetime parameter {search_request.datetime} is invalid." + ) from e + + next: Optional[str] = None + prev: Optional[str] = None + + if links := collections_result.get("links"): + next = collections_result["links"].pop("next") + prev = collections_result["links"].pop("prev") - async with request.app.state.get_connection(request, "r") as conn: - collections = await conn.fetchval( - """ - SELECT * FROM all_collections(); - """ - ) linked_collections: List[Collection] = [] + collections = collections_result["collections"] if collections is not None and len(collections) > 0: for c in collections: coll = Collection(**c) @@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: linked_collections.append(coll) - links = [ - { - "rel": Relations.root.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.parent.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.self.value, - "type": MimeTypes.json, - "href": urljoin(base_url, "collections"), - }, - ] - collection_list = Collections(collections=linked_collections or [], links=links) - return collection_list + links = await PagingLinks( + request=request, + next=next, + prev=prev, + ).get_links() + + return Collections( + collections=linked_collections or [], + links=links, + ) async def get_collection( self, collection_id: str, request: Request, **kwargs @@ -383,7 +456,7 @@ async def post_search( return ItemCollection(**item_collection) - async def get_search( # noqa: C901 + async def get_search( self, request: Request, collections: Optional[List[str]] = None, @@ -418,49 +491,15 @@ async def get_search( # noqa: C901 "query": orjson.loads(unquote_plus(query)) if query else query, } - if filter: - if filter_lang == "cql2-text": - ast = parse_cql2_text(filter) - base_args["filter"] = orjson.loads(to_cql2(ast)) - base_args["filter-lang"] = "cql2-json" - - if datetime: - base_args["datetime"] = format_datetime_range(datetime) - - if intersects: - base_args["intersects"] = orjson.loads(unquote_plus(intersects)) - - if sortby: - # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form - sort_param = [] - for sort in sortby: - sortparts = re.match(r"^([+-]?)(.*)$", sort) - if sortparts: - sort_param.append( - { - "field": sortparts.group(2).strip(), - "direction": "desc" if sortparts.group(1) == "-" else "asc", - } - ) - base_args["sortby"] = sort_param - - if fields: - includes = set() - excludes = set() - for field in fields: - if field[0] == "-": - excludes.add(field[1:]) - elif field[0] == "+": - includes.add(field[1:]) - else: - includes.add(field) - base_args["fields"] = {"include": includes, "exclude": excludes} - - # Remove None values from dict - clean = {} - for k, v in base_args.items(): - if v is not None and v != []: - clean[k] = v + clean = clean_search_args( + base_args=base_args, + intersects=intersects, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) # Do the request try: @@ -471,3 +510,60 @@ async def get_search( # noqa: C901 ) from e return await self.post_search(search_request, request=request) + + +def clean_search_args( # noqa: C901 + base_args: Dict[str, Any], + intersects: Optional[str] = None, + datetime: Optional[DateTimeType] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, +) -> Dict[str, Any]: + """Clean up search arguments to match format expected by pgstac""" + if filter: + if filter_lang == "cql2-text": + ast = parse_cql2_text(filter) + base_args["filter"] = orjson.loads(to_cql2(ast)) + base_args["filter-lang"] = "cql2-json" + + if datetime: + base_args["datetime"] = format_datetime_range(datetime) + + if intersects: + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) + + if sortby: + # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form + sort_param = [] + for sort in sortby: + sortparts = re.match(r"^([+-]?)(.*)$", sort) + if sortparts: + sort_param.append( + { + "field": sortparts.group(2).strip(), + "direction": "desc" if sortparts.group(1) == "-" else "asc", + } + ) + base_args["sortby"] = sort_param + + if fields: + includes = set() + excludes = set() + for field in fields: + if field[0] == "-": + excludes.add(field[1:]) + elif field[0] == "+": + includes.add(field[1:]) + else: + includes.add(field) + base_args["fields"] = {"include": includes, "exclude": excludes} + + # Remove None values from dict + clean = {} + for k, v in base_args.items(): + if v is not None and v != []: + clean[k] = v + + return clean diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 2077c352..fddc9a3b 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -502,6 +502,33 @@ async def test_collection_queryables(load_test_data, app_client, load_test_colle assert "id" in q["properties"] +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 + + @pytest.mark.asyncio async def test_item_collection_filter_bbox( load_test_data, app_client, load_test_collection diff --git a/tests/conftest.py b/tests/conftest.py index e5955b75..8f9480cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -133,6 +134,7 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] + collection_extensions = [CollectionSearchExtension()] items_get_request_model = create_request_model( model_name="ItemCollectionUri", @@ -149,11 +151,12 @@ def api_client(request, database): ) api = StacApi( settings=api_settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, + collections_get_request_model=CollectionSearchExtension().GET, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), ) diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 3a2183b1..634747bc 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -276,3 +276,30 @@ async def test_get_collections_queryables_links(app_client, load_test_collection f"/collections/{collection_id}", ) assert "Queryables" in [link.get("title") for link in resp.json()["links"]] + + +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2