diff --git a/src/stac_auth_proxy/app.py b/src/stac_auth_proxy/app.py index 0233fd6..1893e37 100644 --- a/src/stac_auth_proxy/app.py +++ b/src/stac_auth_proxy/app.py @@ -12,7 +12,7 @@ from .auth import OpenIdConnectAuth from .config import Settings -from .handlers import OpenApiSpecHandler, ReverseProxyHandler +from .handlers import ReverseProxyHandler, build_openapi_spec_handler from .middleware import AddProcessTimeHeaderMiddleware # from .utils import apply_filter @@ -55,7 +55,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: collections_filter=collections_filter, items_filter=items_filter, ) - openapi_handler = OpenApiSpecHandler( + openapi_handler = build_openapi_spec_handler( proxy=proxy_handler, oidc_config_url=str(settings.oidc_discovery_url), ) @@ -67,7 +67,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: ( proxy_handler.stream if path != settings.openapi_spec_endpoint - else openapi_handler.dispatch + else openapi_handler ), methods=methods, dependencies=[Security(auth_scheme.validated_user)], @@ -80,7 +80,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI: ( proxy_handler.stream if path != settings.openapi_spec_endpoint - else openapi_handler.dispatch + else openapi_handler ), methods=methods, dependencies=[Security(auth_scheme.maybe_validated_user)], diff --git a/src/stac_auth_proxy/auth.py b/src/stac_auth_proxy/auth.py index 1e4826e..1c43dea 100644 --- a/src/stac_auth_proxy/auth.py +++ b/src/stac_auth_proxy/auth.py @@ -4,7 +4,7 @@ import logging import urllib.request from dataclasses import dataclass, field -from typing import Annotated, Any, Callable, Optional, Sequence +from typing import Annotated, Optional, Sequence import jwt from fastapi import HTTPException, Security, security, status @@ -25,8 +25,6 @@ class OpenIdConnectAuth: # Generated attributes auth_scheme: SecurityBase = field(init=False) jwks_client: jwt.PyJWKClient = field(init=False) - validated_user: Callable[..., Any] = field(init=False) - maybe_validated_user: Callable[..., Any] = field(init=False) def __post_init__(self): """Initialize the OIDC authentication class.""" @@ -50,70 +48,80 @@ def __post_init__(self): openIdConnectUrl=str(self.openid_configuration_url), auto_error=False, ) - self.validated_user = self._build(auto_error=True) - self.maybe_validated_user = self._build(auto_error=False) - - def _build(self, auto_error: bool = True): - """Build a dependency for validating an OIDC token.""" - - def valid_token_dependency( - auth_header: Annotated[str, Security(self.auth_scheme)], - required_scopes: security.SecurityScopes, - ): - """Dependency to validate an OIDC token.""" - if not auth_header: + + # Update annotations to support FastAPI's dependency injection + for endpoint in [self.validated_user, self.maybe_validated_user]: + endpoint.__annotations__["auth_header"] = Annotated[ + str, + Security(self.auth_scheme), + ] + + def maybe_validated_user( + self, + auth_header: Annotated[str, Security(...)], + required_scopes: security.SecurityScopes, + ): + """Dependency to validate an OIDC token.""" + return self.validated_user(auth_header, required_scopes, auto_error=False) + + def validated_user( + self, + auth_header: Annotated[str, Security(...)], + required_scopes: security.SecurityScopes, + auto_error: bool = True, + ): + """Dependency to validate an OIDC token.""" + if not auth_header: + if auto_error: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authenticated", + ) + return None + + # Extract token from header + token_parts = auth_header.split(" ") + if len(token_parts) != 2 or token_parts[0].lower() != "bearer": + logger.error(f"Invalid token: {auth_header}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + [_, token] = token_parts + + # Parse & validate token + try: + key = self.jwks_client.get_signing_key_from_jwt(token).key + payload = jwt.decode( + token, + key, + algorithms=["RS256"], + # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40) + audience=self.allowed_jwt_audiences, + ) + except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e: + logger.exception(f"InvalidTokenError: {e=}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + # Validate scopes (if required) + for scope in required_scopes.scopes: + if scope not in payload["scope"]: if auto_error: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Not authenticated", + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={ + "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' + }, ) return None - # Extract token from header - token_parts = auth_header.split(" ") - if len(token_parts) != 2 or token_parts[0].lower() != "bearer": - logger.error(f"Invalid token: {auth_header}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - [_, token] = token_parts - - # Parse & validate token - try: - key = self.jwks_client.get_signing_key_from_jwt(token).key - payload = jwt.decode( - token, - key, - algorithms=["RS256"], - # NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40) - audience=self.allowed_jwt_audiences, - ) - except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e: - logger.exception(f"InvalidTokenError: {e=}") - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) from e - - # Validate scopes (if required) - for scope in required_scopes.scopes: - if scope not in payload["scope"]: - if auto_error: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not enough permissions", - headers={ - "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' - }, - ) - return None - - return payload - - return valid_token_dependency + return payload class OidcFetchError(Exception): diff --git a/src/stac_auth_proxy/filters/template.py b/src/stac_auth_proxy/filters/template.py index 45fa2ef..dd6de37 100644 --- a/src/stac_auth_proxy/filters/template.py +++ b/src/stac_auth_proxy/filters/template.py @@ -1,6 +1,6 @@ """Generate CQL2 filter expressions via Jinja2 templating.""" -from typing import Any, Annotated, Callable +from typing import Annotated, Any, Callable from cql2 import Expr from fastapi import Request, Security diff --git a/src/stac_auth_proxy/handlers/__init__.py b/src/stac_auth_proxy/handlers/__init__.py index 43d2dc3..7b03225 100644 --- a/src/stac_auth_proxy/handlers/__init__.py +++ b/src/stac_auth_proxy/handlers/__init__.py @@ -1,6 +1,6 @@ """Handlers to process requests.""" -from .open_api_spec import OpenApiSpecHandler +from .open_api_spec import build_openapi_spec_handler from .reverse_proxy import ReverseProxyHandler -__all__ = ["OpenApiSpecHandler", "ReverseProxyHandler"] +__all__ = ["build_openapi_spec_handler", "ReverseProxyHandler"] diff --git a/src/stac_auth_proxy/handlers/open_api_spec.py b/src/stac_auth_proxy/handlers/open_api_spec.py index 444f1d9..9d3af20 100644 --- a/src/stac_auth_proxy/handlers/open_api_spec.py +++ b/src/stac_auth_proxy/handlers/open_api_spec.py @@ -1,7 +1,6 @@ """Custom request handlers.""" import logging -from dataclasses import dataclass from fastapi import Request, Response from fastapi.routing import APIRoute @@ -12,17 +11,14 @@ logger = logging.getLogger(__name__) -@dataclass -class OpenApiSpecHandler: - """Handler for OpenAPI spec requests.""" - - proxy: ReverseProxyHandler - oidc_config_url: str - auth_scheme_name: str = "oidcAuth" - - async def dispatch(self, req: Request, res: Response): +def build_openapi_spec_handler( + proxy: ReverseProxyHandler, + oidc_config_url: str, + auth_scheme_name: str = "oidcAuth", +): + async def dispatch(req: Request, res: Response): """Proxy the OpenAPI spec from the upstream STAC API, updating it with OIDC security requirements.""" - oidc_spec_response = await self.proxy.proxy_request(req) + oidc_spec_response = await proxy.proxy_request(req) openapi_spec = oidc_spec_response.json() # Pass along the response headers @@ -45,10 +41,10 @@ async def dispatch(self, req: Request, res: Response): # Add the OIDC security scheme to the components openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[ - self.auth_scheme_name + auth_scheme_name ] = { "type": "openIdConnect", - "openIdConnectUrl": self.oidc_config_url, + "openIdConnectUrl": oidc_config_url, } # Update the paths with the specified security requirements @@ -61,9 +57,9 @@ async def dispatch(self, req: Request, res: Response): if match.name != "FULL": continue # Add the OIDC security requirement - config.setdefault("security", []).append( - {self.auth_scheme_name: []} - ) + config.setdefault("security", []).append({auth_scheme_name: []}) break return openapi_spec + + return dispatch