Skip to content

Commit

Permalink
fix: handle deeply nested security dependencies (#14)
Browse files Browse the repository at this point in the history
* Add failing test

* bugfix: handle deeply nested security dependencies

* bugfix: prevent env pollution

* expand tests for complete coverage
  • Loading branch information
alukach committed Dec 8, 2024
1 parent 44116b7 commit 5ff51ca
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 23 deletions.
31 changes: 17 additions & 14 deletions src/stac_auth_proxy/handlers/open_api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi import Request, Response
from fastapi.routing import APIRoute

from ..utils import safe_headers
from ..utils import has_any_security_requirements, safe_headers
from .reverse_proxy import ReverseProxyHandler

logger = logging.getLogger(__name__)
Expand All @@ -28,26 +28,29 @@ async def dispatch(self, req: Request, res: Response):
# Pass along the response headers
res.headers.update(safe_headers(oidc_spec_response.headers))

# Add the OIDC security scheme to the components
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
self.auth_scheme_name
] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}

proxy_auth_routes = [
r
for r in req.app.routes
# Ignore non-APIRoutes (we can't check their security dependencies)
if isinstance(r, APIRoute)
# Ignore routes that don't have security requirements
and (
r.dependant.security_requirements
or any(d.security_requirements for d in r.dependant.dependencies)
)
and has_any_security_requirements(r.dependant)
]

if not proxy_auth_routes:
logger.warning(
"No routes with security requirements found. OIDC security requirements will not be added."
)
return openapi_spec

# Add the OIDC security scheme to the components
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
self.auth_scheme_name
] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}

# Update the paths with the specified security requirements
for path, method_config in openapi_spec["paths"].items():
for method, config in method_config.items():
Expand All @@ -59,7 +62,7 @@ async def dispatch(self, req: Request, res: Response):
continue
# Add the OIDC security requirement
config.setdefault("security", []).append(
[{self.auth_scheme_name: []}]
{self.auth_scheme_name: []}
)
break

Expand Down
13 changes: 13 additions & 0 deletions src/stac_auth_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from urllib.parse import urlparse

from fastapi.dependencies.models import Dependant
from httpx import Headers


Expand All @@ -29,3 +30,15 @@ def extract_variables(url: str) -> dict:
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
match = re.match(pattern, path)
return {k: v for k, v in match.groupdict().items() if v} if match else {}


def has_any_security_requirements(dependency: Dependant) -> bool:
"""
Recursively check if any dependency within the hierarchy has a non-empty
security_requirements list.
"""
if dependency.security_requirements:
return True
return any(
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
)
47 changes: 39 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pytest fixtures."""

import json
import os
import threading
from typing import Any
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -68,19 +69,42 @@ def source_api():
app = FastAPI(docs_url="/api.html", openapi_url="/api")

for path, methods in {
"/": ["GET"],
"/conformance": ["GET"],
"/queryables": ["GET"],
"/search": ["GET", "POST"],
"/collections": ["GET", "POST"],
"/collections/{collection_id}": ["GET", "PUT", "DELETE"],
"/collections/{collection_id}/items": ["GET", "POST"],
"/": [
"GET",
],
"/conformance": [
"GET",
],
"/queryables": [
"GET",
],
"/search": [
"GET",
"POST",
],
"/collections": [
"GET",
"POST",
],
"/collections/{collection_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/items": [
"GET",
"POST",
],
"/collections/{collection_id}/items/{item_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/bulk_items": ["POST"],
"/collections/{collection_id}/bulk_items": [
"POST",
],
}.items():
for method in methods:
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
Expand Down Expand Up @@ -109,3 +133,10 @@ def source_api_server(source_api):
yield f"http://{host}:{port}"
server.should_exit = True
thread.join()


@pytest.fixture(autouse=True, scope="module")
def mock_env():
"""Clear environment variables to avoid poluting configs from runtime env."""
with patch.dict(os.environ, clear=True):
yield
56 changes: 55 additions & 1 deletion tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)


def test_no_edit_openapi_spec(source_api_server):
def test_no_openapi_spec_endpoint(source_api_server):
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
Expand All @@ -25,6 +25,24 @@ def test_no_edit_openapi_spec(source_api_server):
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})


def test_no_private_endpoints(source_api_server):
"""When no endpoints are private, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint="/api",
private_endpoints={},
default_public=True,
)
client = TestClient(app)
response = client.get("/api")
assert response.status_code == 200
openapi = response.json()
assert "info" in openapi
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})


def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
app = app_factory(
Expand All @@ -39,3 +57,39 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})


def test_oidc_in_openapi_spec_private_endpoints(
source_api: FastAPI, source_api_server: str
):
"""When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec."""
private_endpoints = {
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
"/collections": ["POST"],
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
"/collections/{collection_id}/items": ["POST"],
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
"/collections/{collection_id}/bulk_items": ["POST"],
}
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
private_endpoints=private_endpoints,
)
client = TestClient(app)
openapi = client.get(source_api.openapi_url).raise_for_status().json()
for path, methods in private_endpoints.items():
for method in methods:
openapi_path = openapi["paths"].get(path)
assert openapi_path, f"Path {path} not found in OpenAPI spec"
openapi_path_method = openapi_path.get(method.lower())
assert (
openapi_path_method
), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}"
security = openapi_path_method.get("security")
assert security, f"Security not found for {path!r} {method.lower()!r}"
assert any(
"oidcAuth" in s for s in security
), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'

0 comments on commit 5ff51ca

Please sign in to comment.