diff --git a/CHANGES.md b/CHANGES.md index ea020992..6f2106eb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,13 +2,23 @@ ## [Unreleased] - TBD +### Added + +* Add attributes to `stac_fastapi.api.app.StacApi` to enable customization of request model for: + - `/collections`: **collections_get_request_model**, default to `EmptyRequest` + - `/collections/{collection_id}`: **collection_get_request_model**, default to `CollectionUri` + - `/collections/{collection_id}/items`: **items_get_request_model**, default to `ItemCollectionUri` + - `/collections/{collection_id}/items/{item_id}`: **item_get_request_model**, default to `ItemUri` + ### Fixed * Updated default filter language in filter extension's POST search request model to match the extension's documentation [#711](https://github.com/stac-utils/stac-fastapi/issues/711) ### Removed -* Removed the Filter Extension depenency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716) +* Removed the Filter Extension dependency from `AggregationExtensionPostRequest` and `AggregationExtensionGetRequest` [#716](https://github.com/stac-utils/stac-fastapi/pull/716) +* `pagination_extension` attribute in `stac_fastapi.api.app.StacApi` +* remove use of `pagination_extension` in `register_get_item_collection` function (User now need to construct the request model and pass it using `items_get_request_model` attribute) ## [3.0.0a3] - 2024-06-13 diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 5fe7f9d0..aa4a0cc0 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -18,18 +18,18 @@ from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( + APIRequest, CollectionUri, EmptyRequest, GeoJSONResponse, ItemCollectionUri, ItemUri, - create_request_model, ) from stac_fastapi.api.openapi import update_openapi from stac_fastapi.api.routes import Scope, add_route_dependencies, create_async_endpoint # TODO: make this module not depend on `stac_fastapi.extensions` -from stac_fastapi.extensions.core import FieldsExtension, TokenPaginationExtension +from stac_fastapi.extensions.core import FieldsExtension from stac_fastapi.types.config import ApiSettings, Settings from stac_fastapi.types.core import AsyncBaseCoreClient, BaseCoreClient from stac_fastapi.types.extension import ApiExtension @@ -108,7 +108,10 @@ class StacApi: search_post_request_model: Type[BaseSearchPostRequest] = attr.ib( default=BaseSearchPostRequest ) - pagination_extension = attr.ib(default=TokenPaginationExtension) + collections_get_request_model: Type[APIRequest] = attr.ib(default=EmptyRequest) + collection_get_request_model: Type[APIRequest] = attr.ib(default=CollectionUri) + items_get_request_model: Type[APIRequest] = attr.ib(default=ItemCollectionUri) + item_get_request_model: Type[APIRequest] = attr.ib(default=ItemUri) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List[Middleware] = attr.ib( default=attr.Factory( @@ -211,7 +214,9 @@ def register_get_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.get_item, ItemUri), + endpoint=create_async_endpoint( + self.client.get_item, self.item_get_request_model + ), ) def register_post_search(self): @@ -302,7 +307,9 @@ def register_get_collections(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.all_collections, EmptyRequest), + endpoint=create_async_endpoint( + self.client.all_collections, self.collections_get_request_model + ), ) def register_get_collection(self): @@ -329,7 +336,9 @@ def register_get_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.get_collection, CollectionUri), + endpoint=create_async_endpoint( + self.client.get_collection, self.collection_get_request_model + ), ) def register_get_item_collection(self): @@ -338,16 +347,6 @@ def register_get_item_collection(self): Returns: None """ - pagination_extension = self.get_extension(self.pagination_extension) - if pagination_extension is not None: - mixins = [pagination_extension.GET] - else: - mixins = None - request_model = create_request_model( - "ItemCollectionURI", - base_model=ItemCollectionUri, - mixins=mixins, - ) self.router.add_api_route( name="Get ItemCollection", path="/collections/{collection_id}/items", @@ -366,7 +365,9 @@ def register_get_item_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint(self.client.item_collection, request_model), + endpoint=create_async_endpoint( + self.client.item_collection, self.items_get_request_model + ), ) def register_core(self): diff --git a/stac_fastapi/api/tests/test_api.py b/stac_fastapi/api/tests/test_api.py index d559a377..7db4d9a5 100644 --- a/stac_fastapi/api/tests/test_api.py +++ b/stac_fastapi/api/tests/test_api.py @@ -2,6 +2,7 @@ from starlette.testclient import TestClient from stac_fastapi.api.app import StacApi +from stac_fastapi.api.models import ItemCollectionUri, create_request_model from stac_fastapi.extensions.core import ( TokenPaginationExtension, TransactionExtension, @@ -13,6 +14,13 @@ class TestRouteDependencies: @staticmethod def _build_api(**overrides): settings = config.ApiSettings() + + items_get_request_model = create_request_model( + "ItemCollectionURI", + base_model=ItemCollectionUri, + mixins=[TokenPaginationExtension().GET], + ) + return StacApi( **{ "settings": settings, @@ -23,6 +31,7 @@ def _build_api(**overrides): ), TokenPaginationExtension(), ], + "items_get_request_model": items_get_request_model, **overrides, } ) diff --git a/stac_fastapi/api/tests/test_app.py b/stac_fastapi/api/tests/test_app.py index 829982b5..64695de8 100644 --- a/stac_fastapi/api/tests/test_app.py +++ b/stac_fastapi/api/tests/test_app.py @@ -1,13 +1,19 @@ +from dataclasses import dataclass from datetime import datetime from typing import List, Optional, Union import pytest +from fastapi import Path, Query from fastapi.testclient import TestClient from pydantic import ValidationError from stac_pydantic import api from stac_fastapi.api import app -from stac_fastapi.api.models import create_get_request_model, create_post_request_model +from stac_fastapi.api.models import ( + APIRequest, + create_get_request_model, + create_post_request_model, +) from stac_fastapi.extensions.core import FieldsExtension, FilterExtension from stac_fastapi.types import stac from stac_fastapi.types.config import ApiSettings @@ -294,3 +300,64 @@ def item_collection( else: assert get_search.status_code == 200, get_search.text assert post_search.status_code == 200, post_search.text + + +def test_request_model(AsyncTestCoreClient): + """Test if request models are passed correctly.""" + + @dataclass + class CollectionsRequest(APIRequest): + user: str = Query(...) + + @dataclass + class CollectionRequest(APIRequest): + collection_id: str = Path(description="Collection ID") + user: str = Query(...) + + @dataclass + class ItemsRequest(APIRequest): + collection_id: str = Path(description="Collection ID") + user: str = Query(...) + + @dataclass + class ItemRequest(APIRequest): + collection_id: str = Path(description="Collection ID") + item_id: str = Path(description="Item ID") + user: str = Query(...) + + test_app = app.StacApi( + settings=ApiSettings(), + client=AsyncTestCoreClient(), + collections_get_request_model=CollectionsRequest, + collection_get_request_model=CollectionRequest, + items_get_request_model=ItemsRequest, + item_get_request_model=ItemRequest, + extensions=[], + ) + + with TestClient(test_app.app) as client: + resp = client.get("/collections") + assert resp.status_code == 400 + + resp = client.get("/collections", params={"user": "luke"}) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection") + assert resp.status_code == 400 + + resp = client.get("/collections/test_collection", params={"user": "luke"}) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection/items") + assert resp.status_code == 400 + + resp = client.get("/collections/test_collection/items", params={"user": "luke"}) + assert resp.status_code == 200 + + resp = client.get("/collections/test_collection/items/test_item") + assert resp.status_code == 400 + + resp = client.get( + "/collections/test_collection/items/test_item", params={"user": "luke"} + ) + assert resp.status_code == 200