Skip to content

Commit

Permalink
add default CORS middleware (#441)
Browse files Browse the repository at this point in the history
* add default CORS middleware

* test that default cors middleware is working
  • Loading branch information
geospatial-jeff authored Aug 4, 2022
1 parent 0d36b76 commit 397acac
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 3 deletions.
6 changes: 4 additions & 2 deletions stac_fastapi/api/stac_fastapi/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from starlette.responses import JSONResponse, Response

from stac_fastapi.api.errors import DEFAULT_STATUS_CODES, add_exception_handlers
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
from stac_fastapi.api.middleware import CORSMiddleware, ProxyHeaderMiddleware
from stac_fastapi.api.models import (
APIRequest,
CollectionUri,
Expand Down Expand Up @@ -93,7 +93,9 @@ class StacApi:
pagination_extension = attr.ib(default=TokenPaginationExtension)
response_class: Type[Response] = attr.ib(default=JSONResponse)
middlewares: List = attr.ib(
default=attr.Factory(lambda: [BrotliMiddleware, ProxyHeaderMiddleware])
default=attr.Factory(
lambda: [BrotliMiddleware, CORSMiddleware, ProxyHeaderMiddleware]
)
)
route_dependencies: List[Tuple[List[Scope], List[Depends]]] = attr.ib(default=[])

Expand Down
38 changes: 37 additions & 1 deletion stac_fastapi/api/stac_fastapi/api/middleware.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,48 @@
"""api middleware."""

import re
import typing
from http.client import HTTP_PORT, HTTPS_PORT
from typing import List, Tuple

from starlette.middleware.cors import CORSMiddleware as _CORSMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send


class CORSMiddleware(_CORSMiddleware):
"""
Subclass of Starlette's standard CORS middleware with default values set to those reccomended by the STAC API spec.
https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors
"""

def __init__(
self,
app: ASGIApp,
allow_origins: typing.Sequence[str] = ("*",),
allow_methods: typing.Sequence[str] = (
"OPTIONS",
"POST",
"GET",
),
allow_headers: typing.Sequence[str] = ("Content-Type",),
allow_credentials: bool = False,
allow_origin_regex: typing.Optional[str] = None,
expose_headers: typing.Sequence[str] = (),
max_age: int = 600,
) -> None:
"""Create CORS middleware."""
super().__init__(
app,
allow_origins,
allow_methods,
allow_headers,
allow_credentials,
allow_origin_regex,
expose_headers,
max_age,
)


class ProxyHeaderMiddleware:
"""
Account for forwarding headers when deriving base URL.
Expand Down
19 changes: 19 additions & 0 deletions stac_fastapi/api/tests/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from unittest import mock

import pytest
from starlette.applications import Starlette
from starlette.testclient import TestClient

from stac_fastapi.api.app import StacApi
from stac_fastapi.api.middleware import ProxyHeaderMiddleware
from stac_fastapi.types.config import ApiSettings
from stac_fastapi.types.core import BaseCoreClient


@pytest.fixture
Expand All @@ -10,6 +16,13 @@ def proxy_header_middleware() -> ProxyHeaderMiddleware:
return ProxyHeaderMiddleware(app)


@pytest.fixture
def test_client() -> TestClient:
app = StacApi(settings=ApiSettings(), client=mock.create_autospec(BaseCoreClient))
with TestClient(app.app) as client:
yield client


@pytest.mark.parametrize(
"headers,key,expected",
[
Expand Down Expand Up @@ -138,3 +151,9 @@ def test_get_forwarded_url_parts(
):
actual = proxy_header_middleware._get_forwarded_url_parts(scope)
assert actual == expected


def test_cors_middleware(test_client):
resp = test_client.get("/_mgmt/ping", headers={"Origin": "http://netloc"})
assert resp.status_code == 200
assert resp.headers["access-control-allow-origin"] == "*"

0 comments on commit 397acac

Please sign in to comment.