diff --git a/stac_fastapi/pgstac/app.py b/stac_fastapi/pgstac/app.py index e4e07141..7d22df8f 100644 --- a/stac_fastapi/pgstac/app.py +++ b/stac_fastapi/pgstac/app.py @@ -83,15 +83,12 @@ post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) -# will only use parameters defined in collections_get_request_model -collection_search_model = create_post_request_model(extensions, base_model=PgstacSearch) - api = StacApi( settings=settings, extensions=extensions + [collection_search_extension], client=CoreCrudClient( post_request_model=post_request_model, # type: ignore - collection_request_model=collection_search_model, # type: ignore + collections_get_request_model=collections_get_request_model, # type: ignore ), response_class=ORJSONResponse, items_get_request_model=items_get_request_model, diff --git a/stac_fastapi/pgstac/core.py b/stac_fastapi/pgstac/core.py index d8afa18c..a7c322e9 100644 --- a/stac_fastapi/pgstac/core.py +++ b/stac_fastapi/pgstac/core.py @@ -1,5 +1,6 @@ """Item crud client.""" +import json import re from typing import Any, Dict, List, Optional, Set, Union from urllib.parse import unquote_plus, urljoin @@ -13,7 +14,7 @@ from pygeofilter.backends.cql2_json import to_cql2 from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from pypgstac.hydration import hydrate -from stac_fastapi.api.models import JSONResponse +from stac_fastapi.api.models import APIRequest, EmptyRequest, JSONResponse from stac_fastapi.types.core import AsyncBaseCoreClient, Relations from stac_fastapi.types.errors import InvalidQueryParameter, NotFoundError from stac_fastapi.types.requests import get_base_url @@ -38,7 +39,7 @@ class CoreCrudClient(AsyncBaseCoreClient): """Client for core endpoints defined by stac.""" - collection_request_model = attr.ib(default=PgstacSearch) + collections_get_request_model: APIRequest = attr.ib(default=EmptyRequest) async def all_collections( # noqa: C901 self, @@ -83,7 +84,8 @@ async def all_collections( # noqa: C901 # Do the request try: - search_request = self.collection_request_model(**clean) + search_request = self.collections_get_request_model(**clean) + print(search_request) except ValidationError as e: raise HTTPException( status_code=400, detail=f"Invalid parameters provided {e}" @@ -93,7 +95,7 @@ async def all_collections( # noqa: C901 async def _collection_search_base( # noqa: C901 self, - search_request: PgstacSearch, + search_request: APIRequest, request: Request, ) -> Collections: """Cross catalog search (GET). @@ -107,8 +109,12 @@ async def _collection_search_base( # noqa: C901 All collections which match the search criteria. """ base_url = get_base_url(request) - search_request_json = search_request.model_dump_json( - exclude_none=True, by_alias=True + search_request_json = json.dumps( + { + key: value + for key, value in search_request.__dict__.items() + if value is not None + } ) try: @@ -533,7 +539,7 @@ def clean_search_args( # noqa: C901 filter_lang = "cql2-json" base_args["filter"] = orjson.loads(filter_query) - base_args["filter-lang"] = filter_lang + base_args["filter_lang"] = filter_lang if datetime: base_args["datetime"] = format_datetime_range(datetime) diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 83589d08..33335c13 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -730,23 +730,21 @@ async def get_collection( ] post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) - collection_search_model = create_post_request_model( - extensions, base_model=PgstacSearch - ) + collection_search_extension = CollectionSearchExtension.from_extensions( extensions=extensions ) - collections_get_request_model = collection_search_extension.GET + api = StacApi( client=Client( post_request_model=post_request_model, - collection_request_model=collection_search_model, + collections_get_request_model=collection_search_extension.GET, ), settings=settings, extensions=extensions, search_post_request_model=post_request_model, search_get_request_model=get_request_model, - collections_get_request_model=collections_get_request_model, + collections_get_request_model=collection_search_extension.GET, ) app = api.app await connect_to_db(app) diff --git a/tests/conftest.py b/tests/conftest.py index e1dea234..17bfdf63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -151,16 +151,13 @@ def api_client(request, database): ) collections_get_request_model = collection_search_extension.GET - collection_search_model = create_post_request_model( - extensions, base_model=PgstacSearch - ) api = StacApi( settings=api_settings, extensions=extensions + [collection_search_extension], client=CoreCrudClient( post_request_model=search_post_request_model, - collection_request_model=collection_search_model, + collections_get_request_model=collections_get_request_model, ), items_get_request_model=items_get_request_model, search_get_request_model=search_get_request_model,