diff --git a/.github/workflows/cicd.yaml b/.github/workflows/cicd.yaml index 514e31496..f86cb6786 100644 --- a/.github/workflows/cicd.yaml +++ b/.github/workflows/cicd.yaml @@ -76,3 +76,45 @@ jobs: - uses: actions/checkout@v4 - name: Test generating docs run: make docs + + benchmark: + needs: [test] + runs-on: ubuntu-20.04 + steps: + - name: Check out repository code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install types + run: | + python -m pip install ./stac_fastapi/types[dev] + + - name: Install extensions + run: | + python -m pip install ./stac_fastapi/extensions + + - name: Install core api + run: | + python -m pip install ./stac_fastapi/api[dev,benchmark] + + - name: Run Benchmark + run: python -m pytest stac_fastapi/api/tests/benchmarks.py --benchmark-only --benchmark-columns 'min, max, mean, median' --benchmark-json output.json + + - name: Store and benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + name: STAC FastAPI Benchmarks + tool: 'pytest' + output-file-path: output.json + alert-threshold: '130%' + comment-on-alert: true + fail-on-alert: false + # GitHub API token to make a commit comment + github-token: ${{ secrets.GITHUB_TOKEN }} + gh-pages-branch: 'gh-benchmarks' + # Make a commit only if main + auto-push: ${{ github.ref == 'refs/heads/main' }} diff --git a/CHANGES.md b/CHANGES.md index 16bc9a809..bc6475fe1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -2,6 +2,18 @@ ## [Unreleased] +### Changed + +* Make sure FastAPI uses Pydantic validation and serialization by not wrapping endpoint output with a Response object ([#650](https://github.com/stac-utils/stac-fastapi/pull/650)) + +### Removed + +* Deprecate `response_class` option in `stac_fastapi.api.routes.create_async_endpoint` method ([#650](https://github.com/stac-utils/stac-fastapi/pull/650)) + +### Added + +* Add benchmark in CI ([#650](https://github.com/stac-utils/stac-fastapi/pull/650)) + ## [2.4.9] - 2023-11-17 ### Added diff --git a/stac_fastapi/api/setup.py b/stac_fastapi/api/setup.py index 1e3b8002f..9dfa86ac9 100644 --- a/stac_fastapi/api/setup.py +++ b/stac_fastapi/api/setup.py @@ -23,6 +23,9 @@ "requests", "pystac[validation]==1.*", ], + "benchmark": [ + "pytest-benchmark", + ], "docs": ["mkdocs", "mkdocs-material", "pdocs"], } diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index 28fff912c..557896d8f 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -132,9 +132,7 @@ def register_landing_page(self): response_model_exclude_unset=False, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.landing_page, EmptyRequest, self.response_class - ), + endpoint=create_async_endpoint(self.client.landing_page, EmptyRequest), ) def register_conformance_classes(self): @@ -153,9 +151,7 @@ def register_conformance_classes(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.conformance, EmptyRequest, self.response_class - ), + endpoint=create_async_endpoint(self.client.conformance, EmptyRequest), ) def register_get_item(self): @@ -172,9 +168,7 @@ def register_get_item(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.get_item, ItemUri, GeoJSONResponse - ), + endpoint=create_async_endpoint(self.client.get_item, ItemUri), ) def register_post_search(self): @@ -195,7 +189,7 @@ def register_post_search(self): response_model_exclude_none=True, methods=["POST"], endpoint=create_async_endpoint( - self.client.post_search, self.search_post_request_model, GeoJSONResponse + self.client.post_search, self.search_post_request_model ), ) @@ -217,7 +211,7 @@ def register_get_search(self): response_model_exclude_none=True, methods=["GET"], endpoint=create_async_endpoint( - self.client.get_search, self.search_get_request_model, GeoJSONResponse + self.client.get_search, self.search_get_request_model ), ) @@ -237,9 +231,7 @@ def register_get_collections(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.all_collections, EmptyRequest, self.response_class - ), + endpoint=create_async_endpoint(self.client.all_collections, EmptyRequest), ) def register_get_collection(self): @@ -256,9 +248,7 @@ def register_get_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.get_collection, CollectionUri, self.response_class - ), + endpoint=create_async_endpoint(self.client.get_collection, CollectionUri), ) def register_get_item_collection(self): @@ -287,9 +277,7 @@ def register_get_item_collection(self): response_model_exclude_unset=True, response_model_exclude_none=True, methods=["GET"], - endpoint=create_async_endpoint( - self.client.item_collection, request_model, GeoJSONResponse - ), + endpoint=create_async_endpoint(self.client.item_collection, request_model), ) def register_core(self): diff --git a/stac_fastapi/api/stac_fastapi/api/routes.py b/stac_fastapi/api/stac_fastapi/api/routes.py index f4eb759af..66b76d2d7 100644 --- a/stac_fastapi/api/stac_fastapi/api/routes.py +++ b/stac_fastapi/api/stac_fastapi/api/routes.py @@ -1,6 +1,8 @@ """Route factories.""" + import functools import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Type, TypedDict, Union from fastapi import Depends, params @@ -8,18 +10,16 @@ from pydantic import BaseModel from starlette.concurrency import run_in_threadpool from starlette.requests import Request -from starlette.responses import JSONResponse, Response +from starlette.responses import Response from starlette.routing import BaseRoute, Match from starlette.status import HTTP_204_NO_CONTENT from stac_fastapi.api.models import APIRequest -def _wrap_response(resp: Any, response_class: Type[Response]) -> Response: - if isinstance(resp, Response): +def _wrap_response(resp: Any) -> Any: + if resp is not None: return resp - elif resp is not None: - return response_class(resp) else: # None is returned as 204 No Content return Response(status_code=HTTP_204_NO_CONTENT) @@ -37,12 +37,19 @@ async def run(*args, **kwargs): def create_async_endpoint( func: Callable, request_model: Union[Type[APIRequest], Type[BaseModel], Dict], - response_class: Type[Response] = JSONResponse, + response_class: Optional[Type[Response]] = None, ): """Wrap a function in a coroutine which may be used to create a FastAPI endpoint. Synchronous functions are executed asynchronously using a background thread. """ + + if response_class: + warnings.warns( + "`response_class` option is deprecated, please set the Response class directly in the endpoint.", # noqa: E501 + DeprecationWarning, + ) + if not inspect.iscoroutinefunction(func): func = sync_to_async(func) @@ -53,9 +60,7 @@ async def _endpoint( request_data: request_model = Depends(), # type:ignore ): """Endpoint.""" - return _wrap_response( - await func(request=request, **request_data.kwargs()), response_class - ) + return _wrap_response(await func(request=request, **request_data.kwargs())) elif issubclass(request_model, BaseModel): @@ -64,9 +69,7 @@ async def _endpoint( request_data: request_model, # type:ignore ): """Endpoint.""" - return _wrap_response( - await func(request_data, request=request), response_class - ) + return _wrap_response(await func(request_data, request=request)) else: @@ -75,9 +78,7 @@ async def _endpoint( request_data: Dict[str, Any], # type:ignore ): """Endpoint.""" - return _wrap_response( - await func(request_data, request=request), response_class - ) + return _wrap_response(await func(request_data, request=request)) return _endpoint diff --git a/stac_fastapi/api/tests/benchmarks.py b/stac_fastapi/api/tests/benchmarks.py new file mode 100644 index 000000000..3a194057d --- /dev/null +++ b/stac_fastapi/api/tests/benchmarks.py @@ -0,0 +1,172 @@ +from datetime import datetime +from typing import List, Optional, Union + +import pytest +from stac_pydantic.api.utils import link_factory +from starlette.testclient import TestClient + +from stac_fastapi.api.app import StacApi +from stac_fastapi.types import stac as stac_types +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient, BaseSearchPostRequest, NumType + +collection_links = link_factory.CollectionLinks("/", "test").create_links() +item_links = link_factory.ItemLinks("/", "test", "test").create_links() + + +collections = [ + stac_types.Collection( + id=f"test_collection_{n}", + title="Test Collection", + description="A test collection", + keywords=["test"], + license="proprietary", + extent={ + "spatial": {"bbox": [[-180, -90, 180, 90]]}, + "temporal": {"interval": [["2000-01-01T00:00:00Z", None]]}, + }, + links=collection_links.dict(exclude_none=True), + ) + for n in range(0, 10) +] + +items = [ + stac_types.Item( + id=f"test_item_{n}", + type="Feature", + geometry={"type": "Point", "coordinates": [0, 0]}, + bbox=[-180, -90, 180, 90], + properties={"datetime": "2000-01-01T00:00:00Z"}, + links=item_links.dict(exclude_none=True), + assets={}, + ) + for n in range(0, 1000) +] + + +class CoreClient(BaseCoreClient): + def post_search( + self, search_request: BaseSearchPostRequest, **kwargs + ) -> stac_types.ItemCollection: + raise NotImplementedError + + def get_search( + self, + collections: Optional[List[str]] = None, + ids: Optional[List[str]] = None, + bbox: Optional[List[NumType]] = None, + intersects: Optional[str] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: Optional[int] = 10, + **kwargs, + ) -> stac_types.ItemCollection: + raise NotImplementedError + + def get_item(self, item_id: str, collection_id: str, **kwargs) -> stac_types.Item: + raise NotImplementedError + + def all_collections(self, **kwargs) -> stac_types.Collections: + return stac_types.Collections( + collections=collections, + links=[ + {"href": "test", "rel": "root"}, + {"href": "test", "rel": "self"}, + {"href": "test", "rel": "parent"}, + ], + ) + + def get_collection(self, collection_id: str, **kwargs) -> stac_types.Collection: + return collections[0] + + def item_collection( + self, + collection_id: str, + bbox: Optional[List[Union[float, int]]] = None, + datetime: Optional[Union[str, datetime]] = None, + limit: int = 10, + token: str = None, + **kwargs, + ) -> stac_types.ItemCollection: + return stac_types.ItemCollection( + type="FeatureCollection", + features=items[0:limit], + links=[ + {"href": "test", "rel": "root"}, + {"href": "test", "rel": "self"}, + {"href": "test", "rel": "parent"}, + ], + ) + + +@pytest.fixture(autouse=True) +def client_validation() -> TestClient: + settings = ApiSettings(enable_response_models=True) + app = StacApi(settings=settings, client=CoreClient()) + with TestClient(app.app) as client: + yield client + + +@pytest.fixture(autouse=True) +def client_no_validation() -> TestClient: + settings = ApiSettings(enable_response_models=False) + app = StacApi(settings=settings, client=CoreClient()) + with TestClient(app.app) as client: + yield client + + +@pytest.mark.parametrize("limit", [1, 10, 50, 100, 200, 250, 1000]) +@pytest.mark.parametrize("validate", [True, False]) +def test_benchmark_items( + benchmark, client_validation, client_no_validation, validate, limit +): + """Benchmark items endpoint.""" + params = {"limit": limit} + + def f(p): + if validate: + return client_validation.get("/collections/fake_collection/items", params=p) + else: + return client_no_validation.get( + "/collections/fake_collection/items", params=p + ) + + benchmark.group = "Items With Model validation" if validate else "Items" + + response = benchmark(f, params) + assert response.status_code == 200 + + +@pytest.mark.parametrize("validate", [True, False]) +def test_benchmark_collection( + benchmark, client_validation, client_no_validation, validate +): + """Benchmark items endpoint.""" + + def f(): + if validate: + return client_validation.get("/collections/fake_collection") + else: + return client_no_validation.get("/collections/fake_collection") + + benchmark.group = "Collection With Model validation" if validate else "Collection" + + response = benchmark(f) + assert response.status_code == 200 + + +@pytest.mark.parametrize("validate", [True, False]) +def test_benchmark_collections( + benchmark, client_validation, client_no_validation, validate +): + """Benchmark items endpoint.""" + + def f(): + if validate: + return client_validation.get("/collections") + else: + return client_no_validation.get("/collections") + + benchmark.group = "Collections With Model validation" if validate else "Collections" + + response = benchmark(f) + assert response.status_code == 200