diff --git a/stac_fastapi/api/stac_fastapi/api/app.py b/stac_fastapi/api/stac_fastapi/api/app.py index eee9ad748..d18844e5c 100644 --- a/stac_fastapi/api/stac_fastapi/api/app.py +++ b/stac_fastapi/api/stac_fastapi/api/app.py @@ -14,7 +14,7 @@ from starlette.responses import JSONResponse, Response from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers -from stac_fastapi.api.middleware import ProxyHeaderMiddleware +from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware from stac_fastapi.api.models import ( APIRequest, CollectionUri, @@ -93,7 +93,9 @@ class StacApi: pagination_extension = attr.ib(default=TokenPaginationExtension) response_class: Type[Response] = attr.ib(default=JSONResponse) middlewares: List = attr.ib( - default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware]) + default=attr.Factory( + lambda: [BrotliMiddleware, CORSMiddleware, ProxyHeaderMiddleware] + ) ) route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[]) diff --git a/stac_fastapi/api/stac_fastapi/api/middleware.py b/stac_fastapi/api/stac_fastapi/api/middleware.py index 9803e7ab0..21b5581b8 100644 --- a/stac_fastapi/api/stac_fastapi/api/middleware.py +++ b/stac_fastapi/api/stac_fastapi/api/middleware.py @@ -1,12 +1,48 @@ """api middleware.""" - import re +import typing from http.client import HTTP_PORT, HTTPS_PORT from typing import List, Tuple +from starlette.middleware.cors import CORSMiddleware as _CORSMiddleware from starlette.types import ASGIApp, Receive, Scope, Send +class CORSMiddleware(_CORSMiddleware): + """ + Subclass of Starlette's standard CORS middleware with default values set to those reccomended by the STAC API spec. + + https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors + """ + + def __init__( + self, + app: ASGIApp, + allow_origins: typing.Sequence[str] = ("*",), + allow_methods: typing.Sequence[str] = ( + "OPTIONS", + "POST", + "GET", + ), + allow_headers: typing.Sequence[str] = ("Content-Type",), + allow_credentials: bool = False, + allow_origin_regex: typing.Optional[str] = None, + expose_headers: typing.Sequence[str] = (), + max_age: int = 600, + ) -> None: + """Create CORS middleware.""" + super().__init__( + app, + allow_origins, + allow_methods, + allow_headers, + allow_credentials, + allow_origin_regex, + expose_headers, + max_age, + ) + + class ProxyHeaderMiddleware: """ Account for forwarding headers when deriving base URL. diff --git a/stac_fastapi/api/tests/test_middleware.py b/stac_fastapi/api/tests/test_middleware.py index cfe299328..e3e90bed4 100644 --- a/stac_fastapi/api/tests/test_middleware.py +++ b/stac_fastapi/api/tests/test_middleware.py @@ -1,7 +1,13 @@ +from unittest import mock + import pytest from starlette.applications import Starlette +from starlette.testclient import TestClient +from stac_fastapi.api.app import StacApi from stac_fastapi.api.middleware import ProxyHeaderMiddleware +from stac_fastapi.types.config import ApiSettings +from stac_fastapi.types.core import BaseCoreClient @pytest.fixture @@ -10,6 +16,13 @@ def proxy_header_middleware() -> ProxyHeaderMiddleware: return ProxyHeaderMiddleware(app) +@pytest.fixture +def test_client() -> TestClient: + app = StacApi(settings=ApiSettings(), client=mock.create_autospec(BaseCoreClient)) + with TestClient(app.app) as client: + yield client + + @pytest.mark.parametrize( "headers,key,expected", [ @@ -138,3 +151,9 @@ def test_get_forwarded_url_parts( ): actual = proxy_header_middleware._get_forwarded_url_parts(scope) assert actual == expected + + +def test_cors_middleware(test_client): + resp = test_client.get("/_mgmt/ping", headers={"Origin": "http://netloc"}) + assert resp.status_code == 200 + assert resp.headers["access-control-allow-origin"] == "*"