From cd323e2edc2eb6ec7eaca263ba71cbaec992891b Mon Sep 17 00:00:00 2001 From: Ian Buss Date: Fri, 8 Nov 2024 11:51:08 +0000 Subject: [PATCH] Ensure lifespans of mounted FastAPI sub-apps are called (#43817) --- airflow/api_fastapi/app.py | 15 +++++++++++++++ airflow/api_fastapi/execution_api/app.py | 10 ++++++++++ tests/api_fastapi/test_app.py | 19 +++++++++++++++++++ 3 files changed, 44 insertions(+) diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 43885724564c1..9ddd97b6cbe04 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -17,8 +17,10 @@ from __future__ import annotations import logging +from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI +from starlette.routing import Mount from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views from airflow.api_fastapi.execution_api.app import create_task_execution_api_app @@ -28,6 +30,18 @@ app: FastAPI | None = None +@asynccontextmanager +async def lifespan(app: FastAPI): + async with AsyncExitStack() as stack: + for route in app.routes: + if isinstance(route, Mount) and isinstance(route.app, FastAPI): + await stack.enter_async_context( + route.app.router.lifespan_context(route.app), + ) + app.state.lifespan_called = True + yield + + def create_app(apps: str = "all") -> FastAPI: apps_list = apps.split(",") if apps else ["all"] @@ -36,6 +50,7 @@ def create_app(apps: str = "all") -> FastAPI: description="Airflow API. All endpoints located under ``/public`` can be used safely, are stable and backward compatible. " "Endpoints located under ``/ui`` are dedicated to the UI and are subject to breaking change " "depending on the need of the frontend. Users should not rely on those but use the public ones instead.", + lifespan=lifespan, ) if "core" in apps_list or "all" in apps_list: diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py index 82c32104adbf7..1751b61bcd54b 100644 --- a/airflow/api_fastapi/execution_api/app.py +++ b/airflow/api_fastapi/execution_api/app.py @@ -17,9 +17,18 @@ from __future__ import annotations +from contextlib import asynccontextmanager + from fastapi import FastAPI +@asynccontextmanager +async def lifespan(app: FastAPI): + """Context manager for the lifespan of the FastAPI app. For now does nothing.""" + app.state.lifespan_called = True + yield + + def create_task_execution_api_app(app: FastAPI) -> FastAPI: """Create FastAPI app for task execution API.""" from airflow.api_fastapi.execution_api.routes import execution_api_router @@ -28,6 +37,7 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI: task_exec_api_app = FastAPI( title="Airflow Task Execution API", description="The private Airflow Task Execution API.", + lifespan=lifespan, ) task_exec_api_app.include_router(execution_api_router) diff --git a/tests/api_fastapi/test_app.py b/tests/api_fastapi/test_app.py index 3dddd827ff444..e18ba6c467519 100644 --- a/tests/api_fastapi/test_app.py +++ b/tests/api_fastapi/test_app.py @@ -19,6 +19,15 @@ from unittest import mock +def test_main_app_lifespan(client): + with client() as test_client: + test_app = test_client.app + + # assert the app was created and lifespan was called + assert test_app + assert test_app.state.lifespan_called, "Lifespan not called on Execution API app." + + @mock.patch("airflow.api_fastapi.app.init_dag_bag") @mock.patch("airflow.api_fastapi.app.init_views") @mock.patch("airflow.api_fastapi.app.init_plugins") @@ -55,6 +64,16 @@ def test_execution_api_app( mock_init_plugins.assert_not_called() +def test_execution_api_app_lifespan(client): + with client(apps="execution") as test_client: + test_app = test_client.app + + # assert the execution app was created and lifespan was called + execution_app = [route.app for route in test_app.router.routes if route.path == "/execution"] + assert execution_app, "Execution API app not found in FastAPI app." + assert execution_app[0].state.lifespan_called, "Lifespan not called on Execution API app." + + @mock.patch("airflow.api_fastapi.app.init_dag_bag") @mock.patch("airflow.api_fastapi.app.init_views") @mock.patch("airflow.api_fastapi.app.init_plugins")