From 5ff51caf5bde8c2362435aba4c7483d7897786f4 Mon Sep 17 00:00:00 2001 From: Anthony Lukach Date: Sat, 7 Dec 2024 20:44:28 -0800 Subject: [PATCH] fix: handle deeply nested security dependencies (#14) * Add failing test * bugfix: handle deeply nested security dependencies * bugfix: prevent env pollution * expand tests for complete coverage --- src/stac_auth_proxy/handlers/open_api_spec.py | 31 +++++----- src/stac_auth_proxy/utils.py | 13 +++++ tests/conftest.py | 47 +++++++++++++--- tests/test_openapi.py | 56 ++++++++++++++++++- 4 files changed, 124 insertions(+), 23 deletions(-) diff --git a/src/stac_auth_proxy/handlers/open_api_spec.py b/src/stac_auth_proxy/handlers/open_api_spec.py index e2f3394..444f1d9 100644 --- a/src/stac_auth_proxy/handlers/open_api_spec.py +++ b/src/stac_auth_proxy/handlers/open_api_spec.py @@ -6,7 +6,7 @@ from fastapi import Request, Response from fastapi.routing import APIRoute -from ..utils import safe_headers +from ..utils import has_any_security_requirements, safe_headers from .reverse_proxy import ReverseProxyHandler logger = logging.getLogger(__name__) @@ -28,26 +28,29 @@ async def dispatch(self, req: Request, res: Response): # Pass along the response headers res.headers.update(safe_headers(oidc_spec_response.headers)) - # Add the OIDC security scheme to the components - openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[ - self.auth_scheme_name - ] = { - "type": "openIdConnect", - "openIdConnectUrl": self.oidc_config_url, - } - proxy_auth_routes = [ r for r in req.app.routes # Ignore non-APIRoutes (we can't check their security dependencies) if isinstance(r, APIRoute) # Ignore routes that don't have security requirements - and ( - r.dependant.security_requirements - or any(d.security_requirements for d in r.dependant.dependencies) - ) + and has_any_security_requirements(r.dependant) ] + if not proxy_auth_routes: + logger.warning( + "No routes with security requirements found. OIDC security requirements will not be added." + ) + return openapi_spec + + # Add the OIDC security scheme to the components + openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[ + self.auth_scheme_name + ] = { + "type": "openIdConnect", + "openIdConnectUrl": self.oidc_config_url, + } + # Update the paths with the specified security requirements for path, method_config in openapi_spec["paths"].items(): for method, config in method_config.items(): @@ -59,7 +62,7 @@ async def dispatch(self, req: Request, res: Response): continue # Add the OIDC security requirement config.setdefault("security", []).append( - [{self.auth_scheme_name: []}] + {self.auth_scheme_name: []} ) break diff --git a/src/stac_auth_proxy/utils.py b/src/stac_auth_proxy/utils.py index c28c89b..8f047c5 100644 --- a/src/stac_auth_proxy/utils.py +++ b/src/stac_auth_proxy/utils.py @@ -3,6 +3,7 @@ import re from urllib.parse import urlparse +from fastapi.dependencies.models import Dependant from httpx import Headers @@ -29,3 +30,15 @@ def extract_variables(url: str) -> dict: pattern = r"^/collections/(?P[^/]+)(?:/(?:items|bulk_items)(?:/(?P[^/]+))?)?/?$" match = re.match(pattern, path) return {k: v for k, v in match.groupdict().items() if v} if match else {} + + +def has_any_security_requirements(dependency: Dependant) -> bool: + """ + Recursively check if any dependency within the hierarchy has a non-empty + security_requirements list. + """ + if dependency.security_requirements: + return True + return any( + has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies + ) diff --git a/tests/conftest.py b/tests/conftest.py index 3c4ed7f..bdf013e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """Pytest fixtures.""" import json +import os import threading from typing import Any from unittest.mock import MagicMock, patch @@ -68,19 +69,42 @@ def source_api(): app = FastAPI(docs_url="/api.html", openapi_url="/api") for path, methods in { - "/": ["GET"], - "/conformance": ["GET"], - "/queryables": ["GET"], - "/search": ["GET", "POST"], - "/collections": ["GET", "POST"], - "/collections/{collection_id}": ["GET", "PUT", "DELETE"], - "/collections/{collection_id}/items": ["GET", "POST"], + "/": [ + "GET", + ], + "/conformance": [ + "GET", + ], + "/queryables": [ + "GET", + ], + "/search": [ + "GET", + "POST", + ], + "/collections": [ + "GET", + "POST", + ], + "/collections/{collection_id}": [ + "GET", + "PUT", + "PATCH", + "DELETE", + ], + "/collections/{collection_id}/items": [ + "GET", + "POST", + ], "/collections/{collection_id}/items/{item_id}": [ "GET", "PUT", + "PATCH", "DELETE", ], - "/collections/{collection_id}/bulk_items": ["POST"], + "/collections/{collection_id}/bulk_items": [ + "POST", + ], }.items(): for method in methods: # NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function " @@ -109,3 +133,10 @@ def source_api_server(source_api): yield f"http://{host}:{port}" server.should_exit = True thread.join() + + +@pytest.fixture(autouse=True, scope="module") +def mock_env(): + """Clear environment variables to avoid poluting configs from runtime env.""" + with patch.dict(os.environ, clear=True): + yield diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 4f579df..70ea61c 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -9,7 +9,7 @@ ) -def test_no_edit_openapi_spec(source_api_server): +def test_no_openapi_spec_endpoint(source_api_server): """When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered.""" app = app_factory( upstream_url=source_api_server, @@ -25,6 +25,24 @@ def test_no_edit_openapi_spec(source_api_server): assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {}) +def test_no_private_endpoints(source_api_server): + """When no endpoints are private, the proxied OpenAPI spec is unaltered.""" + app = app_factory( + upstream_url=source_api_server, + openapi_spec_endpoint="/api", + private_endpoints={}, + default_public=True, + ) + client = TestClient(app) + response = client.get("/api") + assert response.status_code == 200 + openapi = response.json() + assert "info" in openapi + assert "openapi" in openapi + assert "paths" in openapi + assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {}) + + def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str): """When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details.""" app = app_factory( @@ -39,3 +57,39 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str): assert "openapi" in openapi assert "paths" in openapi assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {}) + + +def test_oidc_in_openapi_spec_private_endpoints( + source_api: FastAPI, source_api_server: str +): + """When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec.""" + private_endpoints = { + # https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods + "/collections": ["POST"], + "/collections/{collection_id}": ["PUT", "PATCH", "DELETE"], + # https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods + "/collections/{collection_id}/items": ["POST"], + "/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"], + # https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension + "/collections/{collection_id}/bulk_items": ["POST"], + } + app = app_factory( + upstream_url=source_api_server, + openapi_spec_endpoint=source_api.openapi_url, + private_endpoints=private_endpoints, + ) + client = TestClient(app) + openapi = client.get(source_api.openapi_url).raise_for_status().json() + for path, methods in private_endpoints.items(): + for method in methods: + openapi_path = openapi["paths"].get(path) + assert openapi_path, f"Path {path} not found in OpenAPI spec" + openapi_path_method = openapi_path.get(method.lower()) + assert ( + openapi_path_method + ), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}" + security = openapi_path_method.get("security") + assert security, f"Security not found for {path!r} {method.lower()!r}" + assert any( + "oidcAuth" in s for s in security + ), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'