diff --git a/airflow/models/dag.py b/airflow/models/dag.py index d2366c0e9e7a0..164e83a3f5732 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -3081,7 +3081,7 @@ def bulk_write_to_db( # Skip these queries entirely if no DAGs can be scheduled to save time. if any(dag.timetable.can_be_scheduled for dag in dags): # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query) - query = cls._get_latest_runs_query(dags=list(existing_dags.keys())) + query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys())) latest_runs = {run.dag_id: run for run in session.scalars(query)} # Get number of active dagruns for all dags we are processing as a single query. @@ -3254,9 +3254,9 @@ def bulk_write_to_db( cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session) @classmethod - def _get_latest_runs_query(cls, dags: list[str]) -> Query: + def _get_latest_runs_stmt(cls, dags: list[str]) -> Select: """ - Query the database to retrieve the last automated run for each dag. + Build a select statement to retrieve the last automated run for each dag. 
:param dags: dags to query """ @@ -3269,7 +3269,7 @@ def _get_latest_runs_query(cls, dags: list[str]) -> Query: DagRun.dag_id == existing_dag_id, DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)), ) - .subquery() + .scalar_subquery() ) query = select(DagRun).where( DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 05681cfe8855d..b46f2b28706e7 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -24,6 +24,7 @@ import pickle import re import sys +import warnings import weakref from contextlib import redirect_stdout from datetime import timedelta @@ -39,6 +40,7 @@ from dateutil.relativedelta import relativedelta from pendulum.tz.timezone import Timezone from sqlalchemy import inspect +from sqlalchemy.exc import SAWarning from airflow import settings from airflow.configuration import conf @@ -4148,34 +4150,35 @@ def test_validate_setup_teardown_trigger_rule(self): dag.validate_setup_teardown() -def test_get_latest_runs_query_one_dag(dag_maker, session): - with dag_maker(dag_id="dag1") as dag1: - ... - query = DAG._get_latest_runs_query(dags=[dag1.dag_id]) - actual = [x.strip() for x in str(query.compile()).splitlines()] - expected = [ - "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end", - "FROM dag_run", - "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date", - "FROM dag_run", - "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))", - ] - assert actual == expected - - -def test_get_latest_runs_query_two_dags(dag_maker, session): - with dag_maker(dag_id="dag1") as dag1: - ... - with dag_maker(dag_id="dag2") as dag2: - ... 
- query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id]) - actual = [x.strip() for x in str(query.compile()).splitlines()] - print("\n".join(actual)) - expected = [ - "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end", - "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date", - "FROM dag_run", - "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1", - "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date", - ] - assert actual == expected +def test_statement_latest_runs_one_dag(): + with warnings.catch_warnings(): + warnings.simplefilter("error", category=SAWarning) + + stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"]) + compiled_stmt = str(stmt.compile()) + actual = [x.strip() for x in compiled_stmt.splitlines()] + expected = [ + "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end", + "FROM dag_run", + "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date", + "FROM dag_run", + "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))", + ] + assert actual == expected, compiled_stmt + + +def test_statement_latest_runs_many_dag(): + with warnings.catch_warnings(): + warnings.simplefilter("error", category=SAWarning) + + stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"]) + compiled_stmt = str(stmt.compile()) + actual = [x.strip() for x in compiled_stmt.splitlines()] + expected = [ + "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end", + "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date", + "FROM dag_run", + "WHERE dag_run.dag_id IN 
(__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1", + "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date", + ] + assert actual == expected, compiled_stmt