Skip to content

Commit

Permalink
Fix 'implicitly coercing SELECT object to scalar subquery' in latest …
Browse files Browse the repository at this point in the history
…dag run statement (apache#37505)

* Fix 'implicitly coercing SELECT object to scalar subquery' in latest dag run statement

* Remove redundant print

* remove redundant dag_maker and session in tests

* Beautify test output
  • Loading branch information
Taragolis authored Feb 19, 2024
1 parent aaec842 commit 51bd26b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 35 deletions.
8 changes: 4 additions & 4 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -3081,7 +3081,7 @@ def bulk_write_to_db(
# Skip these queries entirely if no DAGs can be scheduled to save time.
if any(dag.timetable.can_be_scheduled for dag in dags):
# Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
query = cls._get_latest_runs_query(dags=list(existing_dags.keys()))
query = cls._get_latest_runs_stmt(dags=list(existing_dags.keys()))
latest_runs = {run.dag_id: run for run in session.scalars(query)}

# Get number of active dagruns for all dags we are processing as a single query.
Expand Down Expand Up @@ -3254,9 +3254,9 @@ def bulk_write_to_db(
cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)

@classmethod
def _get_latest_runs_query(cls, dags: list[str]) -> Query:
def _get_latest_runs_stmt(cls, dags: list[str]) -> Select:
"""
Query the database to retrieve the last automated run for each dag.
Build a select statement for retrieve the last automated run for each dag.
:param dags: dags to query
"""
Expand All @@ -3269,7 +3269,7 @@ def _get_latest_runs_query(cls, dags: list[str]) -> Query:
DagRun.dag_id == existing_dag_id,
DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
)
.subquery()
.scalar_subquery()
)
query = select(DagRun).where(
DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
Expand Down
65 changes: 34 additions & 31 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pickle
import re
import sys
import warnings
import weakref
from contextlib import redirect_stdout
from datetime import timedelta
Expand All @@ -39,6 +40,7 @@
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
from sqlalchemy import inspect
from sqlalchemy.exc import SAWarning

from airflow import settings
from airflow.configuration import conf
Expand Down Expand Up @@ -4148,34 +4150,35 @@ def test_validate_setup_teardown_trigger_rule(self):
dag.validate_setup_teardown()


def test_get_latest_runs_query_one_dag(dag_maker, session):
    """Verify the compiled SQL of the latest-runs query for a single dag."""
    with dag_maker(dag_id="dag1") as dag1:
        ...
    query = DAG._get_latest_runs_query(dags=[dag1.dag_id])
    compiled_lines = [line.strip() for line in str(query.compile()).splitlines()]
    # Expected shape: outer select filtered by a correlated max(execution_date) subquery.
    assert compiled_lines == [
        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
        "FROM dag_run",
        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date",
        "FROM dag_run",
        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
    ]


def test_get_latest_runs_query_two_dags(dag_maker, session):
    """Verify the compiled SQL of the latest-runs query for two dags.

    With multiple dags the query joins against a grouped subquery
    (max execution_date per dag_id) instead of a correlated scalar one.
    """
    with dag_maker(dag_id="dag1") as dag1:
        ...
    with dag_maker(dag_id="dag2") as dag2:
        ...
    query = DAG._get_latest_runs_query(dags=[dag1.dag_id, dag2.dag_id])
    actual = [x.strip() for x in str(query.compile()).splitlines()]
    # NOTE: the stray debugging ``print`` was removed; use the assert message
    # below for diagnostics on failure instead.
    expected = [
        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date",
        "FROM dag_run",
        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date",
    ]
    assert actual == expected, "\n".join(actual)
def test_statement_latest_runs_one_dag():
    """Compile the single-dag latest-runs statement and check its SQL.

    SQLAlchemy warnings are escalated to errors so that an implicit
    SELECT-to-scalar-subquery coercion would fail the test.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=SAWarning)

        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag"])
        compiled_stmt = str(stmt.compile())
        # Normalize per-line whitespace before comparing.
        actual = [line.strip() for line in compiled_stmt.splitlines()]
        assert actual == [
            "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
            "FROM dag_run",
            "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date",
            "FROM dag_run",
            "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
        ], compiled_stmt


def test_statement_latest_runs_many_dag():
    """Compile the multi-dag latest-runs statement and check its SQL.

    SQLAlchemy warnings are escalated to errors so that an implicit
    SELECT-to-scalar-subquery coercion would fail the test.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=SAWarning)

        stmt = DAG._get_latest_runs_stmt(dags=["fake-dag-1", "fake-dag-2"])
        compiled_stmt = str(stmt.compile())
        # Normalize per-line whitespace before comparing.
        actual = [line.strip() for line in compiled_stmt.splitlines()]
        assert actual == [
            "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
            "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date",
            "FROM dag_run",
            "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
            "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date",
        ], compiled_stmt

0 comments on commit 51bd26b

Please sign in to comment.