From 930765beca6ed8ad1bd7a0db6d9486598578b9cd Mon Sep 17 00:00:00 2001 From: hrodmn Date: Fri, 16 Aug 2024 09:22:17 -0500 Subject: [PATCH 01/19] add collection-search extension --- .dockerignore | 4 +- .github/workflows/cicd.yaml | 2 +- CHANGES.md | 3 +- stac_fastapi/pgstac/app.py | 31 +++- stac_fastapi/pgstac/core.py | 243 ++++++++++++++++++++--------- tests/api/test_api.py | 27 ++++ tests/conftest.py | 5 +- tests/resources/test_collection.py | 27 ++++ 8 files changed, 259 insertions(+), 83 deletions(-) diff --git a/.dockerignore b/.dockerignore index e187e45a..ddfce7a0 100644 --- a/.dockerignore +++ b/.dockerignore @@ -9,5 +9,7 @@ coverage.xml *.log .git .envrc +*egg-info -venv \ No newline at end of file +venv +env diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 2618036c..d4a03663 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -47,7 +47,7 @@ jobs: runs-on: ubuntu-latest services: pgstac: - image: ghcr.io/stac-utils/pgstac:v0.7.10 + image: ghcr.io/stac-utils/pgstac:v0.8.6 env: POSTGRES_USER: username POSTGRES_PASSWORD: password diff --git a/CHANGES.md b/CHANGES.md index 66031dde..1cfbb637 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,7 +2,8 @@ ## [Unreleased] -- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, https://github.com/stac-utils/stac-fastapi-pgstac/pull/142) +- Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, ) +- Add collection search extension - Fix `filter` extension implementation in `CoreCrudClient` diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 924ea01f..10b05377 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -10,12 +10,14 @@ from fastapi.responses import ORJSONResponse from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import ( + EmptyRequest, ItemCollectionUri, create_get_request_model, create_post_request_model, create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -47,12 +49,26 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } +collections_extensions_map = { + "collection_search": CollectionSearchExtension(), +} + if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): + _enabled_extensions = enabled_extensions.split(",") extensions = [ - extensions_map[extension_name] for extension_name in enabled_extensions.split(",") + extension + for key, extension in extensions_map.items() + if key in _enabled_extensions + ] + collection_extensions = [ + extension + for key, extension in collections_extensions_map.items() + if key in _enabled_extensions ] else: extensions = list(extensions_map.values()) + collection_extensions = list(collections_extensions_map.values()) + if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): items_get_request_model = create_request_model( @@ -64,12 +80,19 @@ else: items_get_request_model = ItemCollectionUri -post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) -get_request_model = create_get_request_model(extensions) +if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): + collections_get_request_model = CollectionSearchExtension().GET +else: + collections_get_request_model = EmptyRequest + +post_request_model = create_post_request_model( + extensions + collection_extensions, base_model=PgstacSearch +) +get_request_model = create_get_request_model(extensions + collection_extensions) api = StacApi( settings=settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7a39b50a..60fb3387 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -14,12 +14,11 @@ from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate from stac_fastapi.api.models import JSONResponse -from stac_fastapi.types.core import AsyncBaseCoreClient +from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url from stac_fastapi.types.rfc3339 import DateTimeType from stac_fastapi.types.stac import Collection, Collections, Item, ItemCollection -from stac_pydantic.links import Relations from stac_pydantic.shared import BBox, MimeTypes from stac_fastapi.pgstac.config import Settings @@ -39,17 +38,100 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - async def all_collections(self, request: Request, **kwargs) -> Collections: - """Read all collections from the database.""" + async def all_collections( # noqa: C901 + self, + request: Request, + # Extensions + bbox: Optional[BBox] = None, + datetime: Optional[DateTimeType] = None, + limit: Optional[int] = None, + query: Optional[str] = None, + token: Optional[str] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, + **kwargs, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /collections`. + + Returns: + Collections which match the search criteria, returns all + collections by default. + """ + + # Parse request parameters + base_args = { + "bbox": bbox, + "limit": limit, + "token": token, + "query": orjson.loads(unquote_plus(query)) if query else query, + } + + clean = clean_search_args( + base_args=base_args, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) + + # Do the request + try: + search_request = self.post_request_model(**clean) + except ValidationError as e: + raise HTTPException( + status_code=400, detail=f"Invalid parameters provided {e}" + ) from e + + return await self._collection_search_base(search_request, request=request) + + async def _collection_search_base( # noqa: C901 + self, + search_request: PgstacSearch, + request: Request, + ) -> Collections: + """Cross catalog search (GET). + + Called with `GET /search`. + + Args: + search_request: search request parameters. + + Returns: + All collections which match the search criteria. + """ base_url = get_base_url(request) + search_request_json = search_request.model_dump_json( + exclude_none=True, by_alias=True + ) + + try: + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=search_request_json, + ) + collections_result: Collections = await conn.fetchval(q, *p) + except InvalidDatetimeFormatError as e: + raise InvalidQueryParameter( + f"Datetime parameter {search_request.datetime} is invalid." + ) from e + + next: Optional[str] = None + prev: Optional[str] = None + + if links := collections_result.get("links"): + next = collections_result["links"].pop("next") + prev = collections_result["links"].pop("prev") - async with request.app.state.get_connection(request, "r") as conn: - collections = await conn.fetchval( - """ - SELECT * FROM all_collections(); - """ - ) linked_collections: List[Collection] = [] + collections = collections_result["collections"] if collections is not None and len(collections) > 0: for c in collections: coll = Collection(**c) @@ -71,25 +153,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections: linked_collections.append(coll) - links = [ - { - "rel": Relations.root.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.parent.value, - "type": MimeTypes.json, - "href": base_url, - }, - { - "rel": Relations.self.value, - "type": MimeTypes.json, - "href": urljoin(base_url, "collections"), - }, - ] - collection_list = Collections(collections=linked_collections or [], links=links) - return collection_list + links = await PagingLinks( + request=request, + next=next, + prev=prev, + ).get_links() + + return Collections( + collections=linked_collections or [], + links=links, + ) async def get_collection( self, collection_id: str, request: Request, **kwargs @@ -386,7 +459,7 @@ async def post_search( return ItemCollection(**item_collection) - async def get_search( # noqa: C901 + async def get_search( self, request: Request, collections: Optional[List[str]] = None, @@ -421,51 +494,15 @@ async def get_search( # noqa: C901 "query": orjson.loads(unquote_plus(query)) if query else query, } - if filter: - if filter_lang == "cql2-text": - filter = to_cql2(parse_cql2_text(filter)) - filter_lang = "cql2-json" - - base_args["filter"] = orjson.loads(filter) - base_args["filter-lang"] = filter_lang - - if datetime: - base_args["datetime"] = format_datetime_range(datetime) - - if intersects: - base_args["intersects"] = orjson.loads(unquote_plus(intersects)) - - if sortby: - # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form - sort_param = [] - for sort in sortby: - sortparts = re.match(r"^([+-]?)(.*)$", sort) - if sortparts: - sort_param.append( - { - "field": sortparts.group(2).strip(), - "direction": "desc" if sortparts.group(1) == "-" else "asc", - } - ) - base_args["sortby"] = sort_param - - if fields: - includes = set() - excludes = set() - for field in fields: - if field[0] == "-": - excludes.add(field[1:]) - elif field[0] == "+": - includes.add(field[1:]) - else: - includes.add(field) - base_args["fields"] = {"include": includes, "exclude": excludes} - - # Remove None values from dict - clean = {} - for k, v in base_args.items(): - if v is not None and v != []: - clean[k] = v + clean = clean_search_args( + base_args=base_args, + intersects=intersects, + datetime=datetime, + fields=fields, + sortby=sortby, + filter=filter, + filter_lang=filter_lang, + ) # Do the request try: @@ -476,3 +513,59 @@ async def get_search( # noqa: C901 ) from e return await self.post_search(search_request, request=request) + + +def clean_search_args( # noqa: C901 + base_args: Dict[str, Any], + intersects: Optional[str] = None, + datetime: Optional[DateTimeType] = None, + fields: Optional[List[str]] = None, + sortby: Optional[str] = None, + filter: Optional[str] = None, + filter_lang: Optional[str] = None, +) -> Dict[str, Any]: + """Clean up search arguments to match format expected by pgstac""" + if filter: + if filter_lang == "cql2-text": + filter = to_cql2(parse_cql2_text(filter)) + filter_lang = "cql2-json" + + if datetime: + base_args["datetime"] = format_datetime_range(datetime) + + if intersects: + base_args["intersects"] = orjson.loads(unquote_plus(intersects)) + + if sortby: + # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form + sort_param = [] + for sort in sortby: + sortparts = re.match(r"^([+-]?)(.*)$", sort) + if sortparts: + sort_param.append( + { + "field": sortparts.group(2).strip(), + "direction": "desc" if sortparts.group(1) == "-" else "asc", + } + ) + base_args["sortby"] = sort_param + + if fields: + includes = set() + excludes = set() + for field in fields: + if field[0] == "-": + excludes.add(field[1:]) + elif field[0] == "+": + includes.add(field[1:]) + else: + includes.add(field) + base_args["fields"] = {"include": includes, "exclude": excludes} + + # Remove None values from dict + clean = {} + for k, v in base_args.items(): + if v is not None and v != []: + clean[k] = v + + return clean diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 2077c352..fddc9a3b 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -502,6 +502,33 @@ async def test_collection_queryables(load_test_data, app_client, load_test_colle assert "id" in q["properties"] +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 + + @pytest.mark.asyncio async def test_item_collection_filter_bbox( load_test_data, app_client, load_test_collection diff --git a/tests/conftest.py b/tests/conftest.py index e5955b75..8f9480cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,7 @@ create_request_model, ) from stac_fastapi.extensions.core import ( + CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, @@ -133,6 +134,7 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] + collection_extensions = [CollectionSearchExtension()] items_get_request_model = create_request_model( model_name="ItemCollectionUri", @@ -149,11 +151,12 @@ def api_client(request, database): ) api = StacApi( settings=api_settings, - extensions=extensions, + extensions=extensions + collection_extensions, client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, + collections_get_request_model=CollectionSearchExtension().GET, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), ) diff --git a/tests/resources/test_collection.py b/tests/resources/test_collection.py index 3a2183b1..634747bc 100644 --- a/tests/resources/test_collection.py +++ b/tests/resources/test_collection.py @@ -276,3 +276,30 @@ async def test_get_collections_queryables_links(app_client, load_test_collection f"/collections/{collection_id}", ) assert "Queryables" in [link.get("title") for link in resp.json()["links"]] + + +@pytest.mark.asyncio +async def test_get_collections_search( + app_client, load_test_collection, load_test2_collection +): + # this search should only return a single collection + resp = await app_client.get( + "/collections", + params={"datetime": "2010-01-01T00:00:00Z/2010-01-02T00:00:00Z"}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + + # same with this one + resp = await app_client.get( + "/collections", + params={"datetime": "2020-01-01T00:00:00Z/.."}, + ) + assert len(resp.json()["collections"]) == 1 + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + + # no params should return both collections + resp = await app_client.get( + "/collections", + ) + assert len(resp.json()["collections"]) == 2 From d06480e88c3cac1a24b517947547d49f2e36f9be Mon Sep 17 00:00:00 2001 From: hrodmn Date: Fri, 16 Aug 2024 10:12:28 -0500 Subject: [PATCH 02/19] define collections_get_request_model --- stac_fastapi/pgstac/app.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 10b05377..3eb3e135 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -81,7 +81,9 @@ items_get_request_model = ItemCollectionUri if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): - collections_get_request_model = CollectionSearchExtension().GET + collections_get_request_model = create_get_request_model( + extensions + collection_extensions + ) else: collections_get_request_model = EmptyRequest @@ -98,6 +100,7 @@ items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, search_post_request_model=post_request_model, + collections_get_request_model=collections_get_request_model, ) app = api.app From 6ce1ea776324392de42b59bce1d5b0a19d96a25f Mon Sep 17 00:00:00 2001 From: hrodmn Date: Fri, 16 Aug 2024 10:53:32 -0500 Subject: [PATCH 03/19] add test for additional extensions with collection-search --- tests/api/test_api.py | 16 ++++++++++++++++ tests/conftest.py | 7 ++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index fddc9a3b..1f0f6f1f 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -528,6 +528,22 @@ async def test_get_collections_search( ) assert len(resp.json()["collections"]) == 2 + # this search should return test collection 1 first + resp = await app_client.get( + "/collections", + params={"sortby": "title"}, + ) + assert resp.json()["collections"][0]["id"] == load_test_collection["id"] + assert resp.json()["collections"][1]["id"] == load_test2_collection.id + + # this search should return test collection 2 first + resp = await app_client.get( + "/collections", + params={"sortby": "-title"}, + ) + assert resp.json()["collections"][1]["id"] == load_test_collection["id"] + assert resp.json()["collections"][0]["id"] == load_test2_collection.id + @pytest.mark.asyncio async def test_item_collection_filter_bbox( diff --git a/tests/conftest.py b/tests/conftest.py index 8f9480cc..4ca5c73c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -149,6 +149,11 @@ def api_client(request, database): search_post_request_model = create_post_request_model( extensions, base_model=PgstacSearch ) + + collections_get_request_model = create_get_request_model( + extensions + collection_extensions + ) + api = StacApi( settings=api_settings, extensions=extensions + collection_extensions, @@ -156,7 +161,7 @@ def api_client(request, database): items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, - collections_get_request_model=CollectionSearchExtension().GET, + collections_get_request_model=collections_get_request_model, response_class=ORJSONResponse, router=APIRouter(prefix=prefix), ) From c597c697ebf69da20116b1f11223f5aac100bbae Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 19 Aug 2024 10:23:18 -0500 Subject: [PATCH 04/19] pass collections_get_request_model to StacApi --- stac_fastapi/pgstac/app.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 3eb3e135..a9be244c 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -24,6 +24,9 @@ TokenPaginationExtension, TransactionExtension, ) +from stac_fastapi.extensions.core.collection_search.request import ( + BaseCollectionSearchGetRequest, +) from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings @@ -81,8 +84,11 @@ items_get_request_model = ItemCollectionUri if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): - collections_get_request_model = create_get_request_model( - extensions + collection_extensions + collections_get_request_model = create_request_model( + model_name="CollectionsGetRequest", + base_model=BaseCollectionSearchGetRequest, + extensions=extensions, + request_type="GET", ) else: collections_get_request_model = EmptyRequest From b3065aba07b0e5efbcc036abf6b07c55173da140 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 19 Aug 2024 10:29:36 -0500 Subject: [PATCH 05/19] do not pass collection extensions to post_request_model --- stac_fastapi/pgstac/app.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index a9be244c..5c58a826 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -93,9 +93,7 @@ else: collections_get_request_model = EmptyRequest -post_request_model = create_post_request_model( - extensions + collection_extensions, base_model=PgstacSearch -) +post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions + collection_extensions) api = StacApi( From 90ef6d41b0f1564a0097a155a88f1eb2e0bbf2dc Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 19 Aug 2024 10:33:12 -0500 Subject: [PATCH 06/19] do not pass collection extensions to get_request_model --- stac_fastapi/pgstac/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 5c58a826..45f5a9c7 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -94,7 +94,7 @@ collections_get_request_model = EmptyRequest post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) -get_request_model = create_get_request_model(extensions + collection_extensions) +get_request_model = create_get_request_model(extensions) api = StacApi( settings=settings, From f560aeccceb72dd743b5887e02e44c4f7a3fd304 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Tue, 20 Aug 2024 21:06:57 -0500 Subject: [PATCH 07/19] Do not add extensions to collection-search extension --- stac_fastapi/pgstac/app.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 45f5a9c7..2c8d79bd 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -17,16 +17,13 @@ create_request_model, ) from stac_fastapi.extensions.core import ( - CollectionSearchExtension, FieldsExtension, FilterExtension, SortExtension, TokenPaginationExtension, TransactionExtension, ) -from stac_fastapi.extensions.core.collection_search.request import ( - BaseCollectionSearchGetRequest, -) +from stac_fastapi.extensions.core.collection_search import CollectionSearchExtension from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings @@ -52,10 +49,6 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } -collections_extensions_map = { - "collection_search": CollectionSearchExtension(), -} - if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): _enabled_extensions = enabled_extensions.split(",") extensions = [ @@ -63,14 +56,9 @@ for key, extension in extensions_map.items() if key in _enabled_extensions ] - collection_extensions = [ - extension - for key, extension in collections_extensions_map.items() - if key in _enabled_extensions - ] else: + _enabled_extensions = list(extensions_map.keys()) + ["collection_search"] extensions = list(extensions_map.values()) - collection_extensions = list(collections_extensions_map.values()) if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): @@ -83,13 +71,10 @@ else: items_get_request_model = ItemCollectionUri -if any(isinstance(ext, CollectionSearchExtension) for ext in collection_extensions): - collections_get_request_model = create_request_model( - model_name="CollectionsGetRequest", - base_model=BaseCollectionSearchGetRequest, - extensions=extensions, - request_type="GET", - ) +if "collection_search" in _enabled_extensions: + collection_extension = CollectionSearchExtension() + collections_get_request_model = collection_extension.GET + extensions.append(collection_extension) else: collections_get_request_model = EmptyRequest @@ -98,7 +83,7 @@ api = StacApi( settings=settings, - extensions=extensions + collection_extensions, + extensions=extensions, client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, From c6b66c56cad01f2b0f732584013694cea2db5c0d Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 23 Sep 2024 06:25:14 -0500 Subject: [PATCH 08/19] use CollectionSearchExtension.from_extensions() --- setup.py | 6 +++--- stac_fastapi/pgstac/app.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 76b52893..74b63833 100644 --- a/setup.py +++ b/setup.py @@ -10,9 +10,9 @@ "orjson", "pydantic", "stac_pydantic==3.1.*", - "stac-fastapi.api~=3.0", - "stac-fastapi.extensions~=3.0", - "stac-fastapi.types~=3.0", + "stac-fastapi.api~=3.0.2", + "stac-fastapi.extensions~=3.0.2", + "stac-fastapi.types~=3.0.2", "asyncpg", "buildpg", "brotli_asgi", diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 2c8d79bd..6f641500 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -72,7 +72,9 @@ items_get_request_model = ItemCollectionUri if "collection_search" in _enabled_extensions: - collection_extension = CollectionSearchExtension() + collection_extension = CollectionSearchExtension.from_extensions( + extensions=extensions + ) collections_get_request_model = collection_extension.GET extensions.append(collection_extension) else: From 54962c4e4d956f7049dec4d530a81b904dd6c115 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 23 Sep 2024 06:56:07 -0500 Subject: [PATCH 09/19] keep extensions and collection_search_extension separate --- stac_fastapi/pgstac/app.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 6f641500..85939903 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -72,12 +72,12 @@ items_get_request_model = ItemCollectionUri if "collection_search" in _enabled_extensions: - collection_extension = CollectionSearchExtension.from_extensions( + collection_search_extension = CollectionSearchExtension.from_extensions( extensions=extensions ) - collections_get_request_model = collection_extension.GET - extensions.append(collection_extension) + collections_get_request_model = collection_search_extension.GET else: + collection_search_extension = None collections_get_request_model = EmptyRequest post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) @@ -85,7 +85,7 @@ api = StacApi( settings=settings, - extensions=extensions, + extensions=extensions + [collection_search_extension], client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, From de5c1a4ab6af6d187db9293dacfedb8590232b44 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Mon, 23 Sep 2024 06:56:38 -0500 Subject: [PATCH 10/19] update tests --- tests/api/test_api.py | 9 +++++++++ tests/conftest.py | 8 +++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 1f0f6f1f..79eecbb6 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -803,6 +803,15 @@ async def test_no_extension( collections = await client.get("http://test/collections") assert collections.status_code == 200, collections.text + # datetime should be ignored + collection_datetime = await client.get( + "http://test/collections/test-collection", + params={ + "datetime": "2000-01-01T00:00:00Z/2000-12-31T00:00:00Z", + }, + ) + assert collection_datetime.text == collection.text + item = await client.get( "http://test/collections/test-collection/items/test-item" ) diff --git a/tests/conftest.py b/tests/conftest.py index 4ca5c73c..fc63514f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -134,7 +134,7 @@ def api_client(request, database): FilterExtension(client=FiltersClient()), BulkTransactionExtension(client=BulkTransactionsClient()), ] - collection_extensions = [CollectionSearchExtension()] + collection_search_extension = CollectionSearchExtension.from_extensions(extensions) items_get_request_model = create_request_model( model_name="ItemCollectionUri", @@ -150,13 +150,11 @@ def api_client(request, database): extensions, base_model=PgstacSearch ) - collections_get_request_model = create_get_request_model( - extensions + collection_extensions - ) + collections_get_request_model = collection_search_extension.GET api = StacApi( settings=api_settings, - extensions=extensions + collection_extensions, + extensions=extensions + [collection_search_extension], client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, From 1dfb484fd7fcb883fceef3083267ae2350496d61 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 25 Sep 2024 05:47:17 -0500 Subject: [PATCH 11/19] filter -> filter_query --- stac_fastapi/pgstac/core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 60fb3387..ab676c66 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -75,7 +75,7 @@ async def all_collections( # noqa: C901 datetime=datetime, fields=fields, sortby=sortby, - filter=filter, + filter_query=filter, filter_lang=filter_lang, ) @@ -500,7 +500,7 @@ async def get_search( datetime=datetime, fields=fields, sortby=sortby, - filter=filter, + filter_query=filter, filter_lang=filter_lang, ) @@ -521,13 +521,13 @@ def clean_search_args( # noqa: C901 datetime: Optional[DateTimeType] = None, fields: Optional[List[str]] = None, sortby: Optional[str] = None, - filter: Optional[str] = None, + filter_query: Optional[str] = None, filter_lang: Optional[str] = None, ) -> Dict[str, Any]: """Clean up search arguments to match format expected by pgstac""" - if filter: + if filter_query: if filter_lang == "cql2-text": - filter = to_cql2(parse_cql2_text(filter)) + filter_query = to_cql2(parse_cql2_text(filter_query)) filter_lang = "cql2-json" if datetime: From fd5c48b3e3988e3dc726c6448d6411045f33fe27 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 25 Sep 2024 05:58:16 -0500 Subject: [PATCH 12/19] add collection_get_request_model to client --- stac_fastapi/pgstac/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index ab676c66..e0b47cfa 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -38,6 +38,8 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" + collection_request_model = attr.ib(default=PgstacSearch) + async def all_collections( # noqa: C901 self, request: Request, @@ -81,7 +83,7 @@ async def all_collections( # noqa: C901 # Do the request try: - search_request = self.post_request_model(**clean) + search_request = self.collection_request_model(**clean) except ValidationError as e: raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" From bfd94f452d8f3ef8033ec88dafcfbe731e177c08 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 25 Sep 2024 06:48:44 -0500 Subject: [PATCH 13/19] add collection_request_model to the client --- stac_fastapi/pgstac/app.py | 8 +++++++- stac_fastapi/pgstac/core.py | 3 +++ tests/api/test_api.py | 19 +++++++++++++++++-- tests/conftest.py | 8 +++++++- 4 files changed, 34 insertions(+), 4 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 85939903..e4e07141 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -83,10 +83,16 @@ post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) +# will only use parameters defined in collections_get_request_model +collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch) + api = StacApi( settings=settings, extensions=extensions + [collection_search_extension], - client=CoreCrudClient(post_request_model=post_request_model), # type: ignore + client=CoreCrudClient( + post_request_model=post_request_model, # type: ignore + collection_request_model=collection_search_model, # type: ignore + ), response_class=ORJSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index e0b47cfa..d8afa18c 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -532,6 +532,9 @@ def clean_search_args( # noqa: C901 filter_query = to_cql2(parse_cql2_text(filter_query)) filter_lang = "cql2-json" + base_args["filter"] = orjson.loads(filter_query) + base_args["filter-lang"] = filter_lang + if datetime: base_args["datetime"] = format_datetime_range(datetime) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 79eecbb6..83589d08 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -12,7 +12,11 @@ from pystac import Collection, Extent, Item, SpatialExtent, TemporalExtent from stac_fastapi.api.app import StacApi from stac_fastapi.api.models import create_get_request_model, create_post_request_model -from stac_fastapi.extensions.core import FieldsExtension, TransactionExtension +from stac_fastapi.extensions.core import ( + CollectionSearchExtension, + FieldsExtension, + TransactionExtension, +) from stac_fastapi.types import stac as stac_types from stac_fastapi.pgstac.core import CoreCrudClient, Settings @@ -726,12 +730,23 @@ async def get_collection( ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) + collection_search_model = create_post_request_model( + extensions, base_model=PgstacSearch + ) + collection_search_extension = CollectionSearchExtension.from_extensions( + extensions=extensions + ) + collections_get_request_model = collection_search_extension.GET api = StacApi( - client=Client(post_request_model=post_request_model), + client=Client( + post_request_model=post_request_model, + collection_request_model=collection_search_model, + ), settings=settings, extensions=extensions, search_post_request_model=post_request_model, search_get_request_model=get_request_model, + collections_get_request_model=collections_get_request_model, ) app = api.app await connect_to_db(app) diff --git a/tests/conftest.py b/tests/conftest.py index fc63514f..e1dea234 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -151,11 +151,17 @@ def api_client(request, database): ) collections_get_request_model = collection_search_extension.GET + collection_search_model = create_post_request_model( + extensions, base_model=PgstacSearch + ) api = StacApi( settings=api_settings, extensions=extensions + [collection_search_extension], - client=CoreCrudClient(post_request_model=search_post_request_model), + client=CoreCrudClient( + post_request_model=search_post_request_model, + collection_request_model=collection_search_model, + ), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, From 7495544a881202f172eba8ddb08ca9f845b5846f Mon Sep 17 00:00:00 2001 From: hrodmn Date: Fri, 27 Sep 2024 06:46:27 -0500 Subject: [PATCH 14/19] recycle collections_get_request_model in client --- stac_fastapi/pgstac/app.py | 5 +---- stac_fastapi/pgstac/core.py | 20 +++++++++++++------- tests/api/test_api.py | 10 ++++------ tests/conftest.py | 5 +---- 4 files changed, 19 insertions(+), 21 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index e4e07141..7d22df8f 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -83,15 +83,12 @@ post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) -# will only use parameters defined in collections_get_request_model -collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch) - api = StacApi( settings=settings, extensions=extensions + [collection_search_extension], client=CoreCrudClient( post_request_model=post_request_model, # type: ignore - collection_request_model=collection_search_model, # type: ignore + collections_get_request_model=collections_get_request_model, # type: ignore ), response_class=ORJSONResponse, items_get_request_model=items_get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index d8afa18c..a7c322e9 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -1,5 +1,6 @@ """Item crud client.""" +import json import re from typing import Any, Dict, List, Optional, Set, Union from urllib.parse import unquote_plus, urljoin @@ -13,7 +14,7 @@ from pygeofilter.backends.cql2_json import to_cql2 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate -from stac_fastapi.api.models import JSONResponse +from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url @@ -38,7 +39,7 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - collection_request_model = attr.ib(default=PgstacSearch) + collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest) async def all_collections( # noqa: C901 self, @@ -83,7 +84,8 @@ async def all_collections( # noqa: C901 # Do the request try: - search_request = self.collection_request_model(**clean) + search_request = self.collections_get_request_model(**clean) + print(search_request) except ValidationError as e: raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" @@ -93,7 +95,7 @@ async def all_collections( # noqa: C901 async def _collection_search_base( # noqa: C901 self, - search_request: PgstacSearch, + search_request: APIRequest, request: Request, ) -> Collections: """Cross catalog search (GET). @@ -107,8 +109,12 @@ async def _collection_search_base( # noqa: C901 All collections which match the search criteria. """ base_url = get_base_url(request) - search_request_json = search_request.model_dump_json( - exclude_none=True, by_alias=True + search_request_json = json.dumps( + { + key: value + for key, value in search_request.__dict__.items() + if value is not None + } ) try: @@ -533,7 +539,7 @@ def clean_search_args( # noqa: C901 filter_lang = "cql2-json" base_args["filter"] = orjson.loads(filter_query) - base_args["filter-lang"] = filter_lang + base_args["filter_lang"] = filter_lang if datetime: base_args["datetime"] = format_datetime_range(datetime) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 83589d08..33335c13 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -730,23 +730,21 @@ async def get_collection( ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) - collection_search_model = create_post_request_model( - extensions, base_model=PgstacSearch - ) + collection_search_extension = CollectionSearchExtension.from_extensions( extensions=extensions ) - collections_get_request_model = collection_search_extension.GET + api = StacApi( client=Client( post_request_model=post_request_model, - collection_request_model=collection_search_model, + collections_get_request_model=collection_search_extension.GET, ), settings=settings, extensions=extensions, search_post_request_model=post_request_model, search_get_request_model=get_request_model, - collections_get_request_model=collections_get_request_model, + collections_get_request_model=collection_search_extension.GET, ) app = api.app await connect_to_db(app) diff --git a/tests/conftest.py b/tests/conftest.py index e1dea234..17bfdf63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -151,16 +151,13 @@ def api_client(request, database): ) collections_get_request_model = collection_search_extension.GET - collection_search_model = create_post_request_model( - extensions, base_model=PgstacSearch - ) api = StacApi( settings=api_settings, extensions=extensions + [collection_search_extension], client=CoreCrudClient( post_request_model=search_post_request_model, - collection_request_model=collection_search_model, + collections_get_request_model=collections_get_request_model, ), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, From 97adfdc4442f0069317fef0d80dec95c988179c0 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Fri, 27 Sep 2024 06:59:19 -0500 Subject: [PATCH 15/19] drop print statement --- stac_fastapi/pgstac/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index a7c322e9..7ffa0a03 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -85,7 +85,6 @@ async def all_collections( # noqa: C901 # Do the request try: search_request = self.collections_get_request_model(**clean) - print(search_request) except ValidationError as e: raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" From 3bb80f4b6cb3554928a7da47978fabe150f142f6 Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 1 Oct 2024 17:04:15 +0200 Subject: [PATCH 16/19] simplify --- stac_fastapi/pgstac/app.py | 4 ++- stac_fastapi/pgstac/core.py | 58 +++++++------------------------------ 2 files changed, 13 insertions(+), 49 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 7d22df8f..5dd37b1d 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -85,7 +85,9 @@ api = StacApi( settings=settings, - extensions=extensions + [collection_search_extension], + extensions=extensions + [collection_search_extension] + if collection_search_extension + else extensions, client=CoreCrudClient( post_request_model=post_request_model, # type: ignore collections_get_request_model=collections_get_request_model, # type: ignore diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index 7ffa0a03..a9d1ba1d 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -64,6 +64,7 @@ async def all_collections( # noqa: C901 Collections which match the search criteria, returns all collections by default. """ + base_url = get_base_url(request) # Parse request parameters base_args = { @@ -73,7 +74,7 @@ async def all_collections( # noqa: C901 "query": orjson.loads(unquote_plus(query)) if query else query, } - clean = clean_search_args( + clean_args = clean_search_args( base_args=base_args, datetime=datetime, fields=fields, @@ -82,53 +83,14 @@ async def all_collections( # noqa: C901 filter_lang=filter_lang, ) - # Do the request - try: - search_request = self.collections_get_request_model(**clean) - except ValidationError as e: - raise HTTPException( - status_code=400, detail=f"Invalid parameters provided {e}" - ) from e - - return await self._collection_search_base(search_request, request=request) - - async def _collection_search_base( # noqa: C901 - self, - search_request: APIRequest, - request: Request, - ) -> Collections: - """Cross catalog search (GET). - - Called with `GET /search`. - - Args: - search_request: search request parameters. - - Returns: - All collections which match the search criteria. - """ - base_url = get_base_url(request) - search_request_json = json.dumps( - { - key: value - for key, value in search_request.__dict__.items() - if value is not None - } - ) - - try: - async with request.app.state.get_connection(request, "r") as conn: - q, p = render( - """ - SELECT * FROM collection_search(:req::text::jsonb); - """, - req=search_request_json, - ) - collections_result: Collections = await conn.fetchval(q, *p) - except InvalidDatetimeFormatError as e: - raise InvalidQueryParameter( - f"Datetime parameter {search_request.datetime} is invalid." - ) from e + async with request.app.state.get_connection(request, "r") as conn: + q, p = render( + """ + SELECT * FROM collection_search(:req::text::jsonb); + """, + req=json.dumps(clean_args), + ) + collections_result: Collections = await conn.fetchval(q, *p) next: Optional[str] = None prev: Optional[str] = None From a37401ca3f48a3f8474c5221a219849e3396c0ce Mon Sep 17 00:00:00 2001 From: vincentsarago Date: Tue, 1 Oct 2024 17:11:05 +0200 Subject: [PATCH 17/19] remove unused --- stac_fastapi/pgstac/app.py | 5 +---- stac_fastapi/pgstac/core.py | 4 +--- tests/api/test_api.py | 5 +---- tests/conftest.py | 5 +---- 4 files changed, 4 insertions(+), 15 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 5dd37b1d..9946c08c 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -88,10 +88,7 @@ extensions=extensions + [collection_search_extension] if collection_search_extension else extensions, - client=CoreCrudClient( - post_request_model=post_request_model, # type: ignore - collections_get_request_model=collections_get_request_model, # type: ignore - ), + client=CoreCrudClient(post_request_model=post_request_model), # type: ignore response_class=ORJSONResponse, items_get_request_model=items_get_request_model, search_get_request_model=get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index a9d1ba1d..e7dcab21 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -14,7 +14,7 @@ from pygeofilter.backends.cql2_json import to_cql2 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate -from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse +from stac_fastapi.api.models import JSONResponse from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url @@ -39,8 +39,6 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest) - async def all_collections( # noqa: C901 self, request: Request, diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 33335c13..b135dce5 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -736,10 +736,7 @@ async def get_collection( ) api = StacApi( - client=Client( - post_request_model=post_request_model, - collections_get_request_model=collection_search_extension.GET, - ), + client=Client(post_request_model=post_request_model), settings=settings, extensions=extensions, search_post_request_model=post_request_model, diff --git a/tests/conftest.py b/tests/conftest.py index 17bfdf63..fc63514f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -155,10 +155,7 @@ def api_client(request, database): api = StacApi( settings=api_settings, extensions=extensions + [collection_search_extension], - client=CoreCrudClient( - post_request_model=search_post_request_model, - collections_get_request_model=collections_get_request_model, - ), + client=CoreCrudClient(post_request_model=search_post_request_model), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model, search_post_request_model=search_post_request_model, From 219eeb1c40645a910872e30ae928253d9a5694b6 Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 2 Oct 2024 08:45:07 -0500 Subject: [PATCH 18/19] clean up control flow for extension-specific logic --- stac_fastapi/pgstac/app.py | 46 ++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index 9946c08c..5e28cd09 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -49,36 +49,34 @@ "bulk_transactions": BulkTransactionExtension(client=BulkTransactionsClient()), } -if enabled_extensions := os.getenv("ENABLED_EXTENSIONS"): - _enabled_extensions = enabled_extensions.split(",") - extensions = [ - extension - for key, extension in extensions_map.items() - if key in _enabled_extensions - ] -else: - _enabled_extensions = list(extensions_map.keys()) + ["collection_search"] - extensions = list(extensions_map.values()) - - -if any(isinstance(ext, TokenPaginationExtension) for ext in extensions): - items_get_request_model = create_request_model( +enabled_extensions = ( + os.environ["ENABLED_EXTENSIONS"].split(",") + if "ENABLED_EXTENSIONS" in os.environ + else list(extensions_map.keys()) + ["collection_search"] +) +extensions = [ + extension for key, extension in extensions_map.items() if key in enabled_extensions +] + +items_get_request_model = ( + create_request_model( model_name="ItemCollectionUri", base_model=ItemCollectionUri, mixins=[TokenPaginationExtension().GET], request_type="GET", ) -else: - items_get_request_model = ItemCollectionUri + if any(isinstance(ext, TokenPaginationExtension) for ext in extensions) + else ItemCollectionUri +) -if "collection_search" in _enabled_extensions: - collection_search_extension = CollectionSearchExtension.from_extensions( - extensions=extensions - ) - collections_get_request_model = collection_search_extension.GET -else: - collection_search_extension = None - collections_get_request_model = EmptyRequest +collection_search_extension = ( + CollectionSearchExtension.from_extensions(extensions) + if "collection_search" in enabled_extensions + else None +) +collections_get_request_model = ( + collection_search_extension.GET if collection_search_extension else EmptyRequest +) post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) From 8d072459dec7b1bf65fa8d4dc2170bdf31bf887b Mon Sep 17 00:00:00 2001 From: hrodmn Date: Wed, 2 Oct 2024 09:08:18 -0500 Subject: [PATCH 19/19] add link to PR in changelog --- CHANGES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 1cfbb637..9aebcfc4 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,7 +3,7 @@ ## [Unreleased] - Fix Docker compose file, so example data can be loaded into database (author @zstatmanweil, ) -- Add collection search extension +- Add collection search extension ([#139](https://github.com/stac-utils/stac-fastapi-pgstac/pull/139)) - Fix `filter` extension implementation in `CoreCrudClient`