From 9a3797a149f17491735a6134109363f681d934c5 Mon Sep 17 00:00:00 2001 From: Henry Rodman Date: Wed, 2 Oct 2024 09:23:56 -0500 Subject: [PATCH 1/2] add collection search extension (#136) * add collection-search extension * define collections_get_request_model * add test for additional extensions with collection-search * pass collections_get_request_model to StacApi * do not pass collection extensions to post_request_model * do not pass collection extensions to get_request_model * Do not add extensions to collection-search extension * use CollectionSearchExtension.from_extensions() * keep extensions and collection_search_extension separate * update tests * filter -> filter_query * add collection_get_request_model to client * add collection_request_model to the client * recycle collections_get_request_model in client * drop print statement * simplify * remove unused * clean up control flow for extension-specific logic * add link to PR in changelog --------- Co-authored-by: vincentsarago --- .dockerignore | 4 +- .github/workflows/cicd.yaml | 2 +- CHANGES.md | 3 +- setup.py | 6 +- stac_fastapi/pgstac/app.py | 39 ++++-- stac_fastapi/pgstac/core.py | 207 +++++++++++++++++++---------- tests/api/test_api.py | 64 ++++++++- tests/conftest.py | 8 +- tests/resources/test_collection.py | 27 ++++ 9 files changed, 269 insertions(+), 91 deletions(-) diff --git a/.dockerignore b/.dockerignore index e187e45a..ddfce7a0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,5 +9,7 @@ coverage.xml *.log .git .envrc +*egg-info -venv \ No newline at end of file +venv +env diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 2618036c..d4a03663 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -47,7 +47,7 @@ jobs: runs-on: ubuntu-latest services: pgstac: - image: ghcr.io/stac-utils/pgstac:v0.7.10 + image: ghcr.io/stac-utils/pgstac:v0.8.6 env: POSTGRES_USER: username POSTGRES_PASSWORD: password diff --git a/CHANGES.md b/CHANGES.md index 66031dde..9aebcfc4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,7 +2,8 @@ ## [Unreleased] -- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, https://github.com/stac-utils/stac-fastapi-pgstac/pull/142) +- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, ) +- Add collection search extension ([#139](https://github.com/stac-utils/stac-fastapi-pgstac/pull/139)) - Fix `filter` extension implementation in `CoreCrudClient` diff --git a/setup.py b/setup.py index 76b52893..74b63833 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=3.0", - "stac-fastapi.extensions~=3.0", - "stac-fastapi.types~=3.0", + "stac-fastapi.api~=3.0.2", + "stac-fastapi.extensions~=3.0.2", + "stac-fastapi.types~=3.0.2", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 924ea01f..5e28cd09 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -10,6 +10,7 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import ( + EmptyRequest, ItemCollectionUri, create_get_request_model, create_post_request_model, @@ -22,6 +23,7 @@ TokenPaginationExtension, TransactionExtension, ) +from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings @@ -47,34 +49,49 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } -if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): - extensions = [ - extensions_map[extension_name] for extension_name in enabled_extensions.split(",") - ] -else: - extensions = list(extensions_map.values()) +enabled_extensions = ( + os.environ["ENABLED_EXTENSIONS"].split(",") + if "ENABLED_EXTENSIONS" in os.environ + else list(extensions_map.keys()) + ["collection_search"] +) +extensions = [ + extension for key, extension in extensions_map.items() if key in enabled_extensions +] -if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): - items_get_request_model = create_request_model( +items_get_request_model = ( + create_request_model( model_name="ItemCollectionUri", base_model=ItemCollectionUri, mixins=[TokenPaginationExtension().GET], request_type="GET", ) -else: - items_get_request_model = ItemCollectionUri + if any(isinstance(ext, TokenPaginationExtension) for ext in extensions) + else ItemCollectionUri +) + +collection_search_extension = ( + CollectionSearchExtension.from_extensions(extensions) + if "collection_search" in enabled_extensions + else None +) +collections_get_request_model = ( + collection_search_extension.GET if collection_search_extension else EmptyRequest +) post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) api = StacApi( settings=settings, - extensions=extensions, + extensions=extensions + [collection_search_extension] + if collection_search_extension + else extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, search_post_request_model=post_request_model, + collections_get_request_model=collections_get_request_model, ) app = api.app diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7a39b50a..e7dcab21 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -1,5 +1,6 @@ """Item crud client.""" +import json import re from typing import Any, Dict, List, Optional, Set, Union from urllib.parse import unquote_plus, urljoin @@ -14,12 +15,11 @@ from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate from stac_fastapi.api.models import JSONResponse -from stac_fastapi.types.core import AsyncBaseCoreClient +from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection -from stac_pydantic.links import Relations from stac_pydantic.shared import BBox, MimeTypes from stac_fastapi.pgstac.config import Settings @@ -39,17 +39,66 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - async def all_collections(self, request: Request, **kwargs) -> Collections: - """Read all collections from the database.""" + async def all_collections( # noqa: C901 + self, + request: Request, + # Extensions + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = None, + query: Optional[str] = None, + token: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /collections`. + + Returns: + Collections which match the search criteria, returns all + collections by default. + """ base_url = get_base_url(request) + # Parse request parameters + base_args = { + "bbox": bbox, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean_args = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter_query=filter, + filter_lang=filter_lang, + ) + async with request.app.state.get_connection(request, "r") as conn: - collections = await conn.fetchval( - """ - SELECT * FROM all_collections(); + q, p = render( """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=json.dumps(clean_args), ) + collections_result: Collections = await conn.fetchval(q, *p) + + next: Optional[str] = None + prev: Optional[str] = None + + if links := collections_result.get("links"): + next = collections_result["links"].pop("next") + prev = collections_result["links"].pop("prev") + linked_collections: List[Collection] = [] + collections = collections_result["collections"] if collections is not None and len(collections) > 0: for c in collections: coll = Collection(**c) @@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: linked_collections.append(coll) - links = [ - { - "rel": Relations.root.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.parent.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.self.value, - "type": MimeTypes.json, - "href": urljoin(base_url, "collections"), - }, - ] - collection_list = Collections(collections=linked_collections or [], links=links) - return collection_list + links = await PagingLinks( + request=request, + next=next, + prev=prev, + ).get_links() + + return Collections( + collections=linked_collections or [], + links=links, + ) async def get_collection( self, collection_id: str, request: Request, **kwargs @@ -386,7 +426,7 @@ async def post_search( return ItemCollection(**item_collection) - async def get_search( # noqa: C901 + async def get_search( self, request: Request, collections: Optional[List[str]] = None, @@ -421,51 +461,15 @@ async def get_search( # noqa: C901 "query": orjson.loads(unquote_plus(query)) if query else query, } - if filter: - if filter_lang == "cql2-text": - filter = to_cql2(parse_cql2_text(filter)) - filter_lang = "cql2-json" - - base_args["filter"] = orjson.loads(filter) - base_args["filter-lang"] = filter_lang - - if datetime: - base_args["datetime"] = format_datetime_range(datetime) - - if intersects: - base_args["intersects"] = orjson.loads(unquote_plus(intersects)) - - if sortby: - # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form - sort_param = [] - for sort in sortby: - sortparts = re.match(r"^([+-]?)(.*)$", sort) - if sortparts: - sort_param.append( - { - "field": sortparts.group(2).strip(), - "direction": "desc" if sortparts.group(1) == "-" else "asc", - } - ) - base_args["sortby"] = sort_param - - if fields: - includes = set() - excludes = set() - for field in fields: - if field[0] == "-": - excludes.add(field[1:]) - elif field[0] == "+": - includes.add(field[1:]) - else: - includes.add(field) - base_args["fields"] = {"include": includes, "exclude": excludes} - - # Remove None values from dict - clean = {} - for k, v in base_args.items(): - if v is not None and v != []: - clean[k] = v + clean = clean_search_args( + base_args=base_args, + intersects=intersects, + datetime=datetime, + fields=fields, + sortby=sortby, + filter_query=filter, + filter_lang=filter_lang, + ) # Do the request try: @@ -476,3 +480,62 @@ async def get_search( # noqa: C901 ) from e return await self.post_search(search_request, request=request) + + +def clean_search_args( # noqa: C901 + base_args: Dict[str, Any], + intersects: Optional[str] = None, + datetime: Optional[DateTimeType] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter_query: Optional[str] = None, + filter_lang: Optional[str] = None, +) -> Dict[str, Any]: + """Clean up search arguments to match format expected by pgstac""" + if filter_query: + if filter_lang == "cql2-text": + filter_query = to_cql2(parse_cql2_text(filter_query)) + filter_lang = "cql2-json" + + base_args["filter"] = orjson.loads(filter_query) + base_args["filter_lang"] = filter_lang + + if datetime: + base_args["datetime"] = format_datetime_range(datetime) + + if intersects: + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) + + if sortby: + # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form + sort_param = [] + for sort in sortby: + sortparts = re.match(r"^([+-]?)(.*)$", sort) + if sortparts: + sort_param.append( + { + "field": sortparts.group(2).strip(), + "direction": "desc" if sortparts.group(1) == "-" else "asc", + } + ) + base_args["sortby"] = sort_param + + if fields: + includes = set() + excludes = set() + for field in fields: + if field[0] == "-": + excludes.add(field[1:]) + elif field[0] == "+": + includes.add(field[1:]) + else: + includes.add(field) + base_args["fields"] = {"include": includes, "exclude": excludes} + + # Remove None values from dict + clean = {} + for k, v in base_args.items(): + if v is not None and v != []: + clean[k] = v + + return clean diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 2077c352..b135dce5 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -12,7 +12,11 @@ from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import create_get_request_model, create_post_request_model -from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension +from stac_fastapi.extensions.core import ( + CollectionSearchExtension, + FieldsExtension, + TransactionExtension, +) from stac_fastapi.types import stac as stac_types from stac_fastapi.pgstac.core import CoreCrudClient, Settings @@ -502,6 +506,49 @@ async def test_collection_queryables(load_test_data, app_client, load_test_colle assert "id" in q["properties"] +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 + + # this search should return test collection 1 first + resp = await app_client.get( + "/collections", + params={"sortby": "title"}, + ) + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + assert resp.json()["collections"][1]["id"] == load_test2_collection.id + + # this search should return test collection 2 first + resp = await app_client.get( + "/collections", + params={"sortby": "-title"}, + ) + assert resp.json()["collections"][1]["id"] == load_test_collection["id"] + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + @pytest.mark.asyncio async def test_item_collection_filter_bbox( load_test_data, app_client, load_test_collection @@ -683,12 +730,18 @@ async def get_collection( ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) + + collection_search_extension = CollectionSearchExtension.from_extensions( + extensions=extensions + ) + api = StacApi( client=Client(post_request_model=post_request_model), settings=settings, extensions=extensions, search_post_request_model=post_request_model, search_get_request_model=get_request_model, + collections_get_request_model=collection_search_extension.GET, ) app = api.app await connect_to_db(app) @@ -760,6 +813,15 @@ async def test_no_extension( collections = await client.get("http://test/collections") assert collections.status_code == 200, collections.text + # datetime should be ignored + collection_datetime = await client.get( + "http://test/collections/test-collection", + params={ + "datetime": "2000-01-01T00:00:00Z/2000-12-31T00:00:00Z", + }, + ) + assert collection_datetime.text == collection.text + item = await client.get( "http://test/collections/test-collection/items/test-item" ) diff --git a/tests/conftest.py b/tests/conftest.py index e5955b75..fc63514f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -133,6 +134,7 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] + collection_search_extension = CollectionSearchExtension.from_extensions(extensions) items_get_request_model = create_request_model( model_name="ItemCollectionUri", @@ -147,13 +149,17 @@ def api_client(request, database): search_post_request_model = create_post_request_model( extensions, base_model=PgstacSearch ) + + collections_get_request_model = collection_search_extension.GET + api = StacApi( settings=api_settings, - extensions=extensions, + extensions=extensions + [collection_search_extension], client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, + collections_get_request_model=collections_get_request_model, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), ) diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 3a2183b1..634747bc 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -276,3 +276,30 @@ async def test_get_collections_queryables_links(app_client, load_test_collection f"/collections/{collection_id}", ) assert "Queryables" in [link.get("title") for link in resp.json()["links"]] + + +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 From 43214976f6cdedb5b889d22f81c622c9b3a3cc0b Mon Sep 17 00:00:00 2001 From: Henry Rodman Date: Wed, 9 Oct 2024 08:52:48 -0500 Subject: [PATCH 2/2] create separate list of collection search extensions (#158) * create separate list of collection search extensions * update changelog * specify collection_extensions in conftest.py * remove last warning --------- Co-authored-by: vincentsarago --- CHANGES.md | 2 +- stac_fastapi/pgstac/app.py | 17 ++++++++++++++++- tests/api/test_api.py | 4 +++- tests/conftest.py | 11 ++++++++++- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 9aebcfc4..ba739ac7 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -4,7 +4,7 @@ - Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, ) - Add collection search extension ([#139](https://github.com/stac-utils/stac-fastapi-pgstac/pull/139)) - +- keep `/search` and `/collections` extensions separate ([#158](https://github.com/stac-utils/stac-fastapi-pgstac/pull/158)) - Fix `filter` extension implementation in `CoreCrudClient` ## [3.0.0] - 2024-08-02 diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 5e28cd09..9ba27e08 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -49,6 +49,14 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } +# some extensions are supported in combination with the collection search extension +collection_extensions_map = { + "query": QueryExtension(), + "sort": SortExtension(), + "fields": FieldsExtension(), + "filter": FilterExtension(client=FiltersClient()), +} + enabled_extensions = ( os.environ["ENABLED_EXTENSIONS"].split(",") if "ENABLED_EXTENSIONS" in os.environ @@ -70,10 +78,17 @@ ) collection_search_extension = ( - CollectionSearchExtension.from_extensions(extensions) + CollectionSearchExtension.from_extensions( + [ + extension + for key, extension in collection_extensions_map.items() + if key in enabled_extensions + ] + ) if "collection_search" in enabled_extensions else None ) + collections_get_request_model = ( collection_search_extension.GET if collection_search_extension else EmptyRequest ) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index b135dce5..34c75f0e 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -732,7 +732,9 @@ async def get_collection( get_request_model = create_get_request_model(extensions) collection_search_extension = CollectionSearchExtension.from_extensions( - extensions=extensions + extensions=[ + FieldsExtension(), + ] ) api = StacApi( diff --git a/tests/conftest.py b/tests/conftest.py index fc63514f..e571cae6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,7 +134,16 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] - collection_search_extension = CollectionSearchExtension.from_extensions(extensions) + + collection_extensions = [ + QueryExtension(), + SortExtension(), + FieldsExtension(), + FilterExtension(client=FiltersClient()), + ] + collection_search_extension = CollectionSearchExtension.from_extensions( + collection_extensions + ) items_get_request_model = create_request_model( model_name="ItemCollectionUri",