Optimize max_execution_date query in single dag case (apache#33242)
We can make better use of an index when we're only dealing with one dag, which is a common case.

---------

Co-authored-by: Elad Kalif <[email protected]>
Co-authored-by: Daniel Standish <[email protected]>
3 people authored Jan 22, 2024
1 parent 2b4da01 commit 10c04a4
Showing 2 changed files with 136 additions and 19 deletions.
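
To make the rationale concrete before the diff, here is a minimal SQLAlchemy sketch, separate from the commit itself, that builds both query shapes against a simplified stand-in model (not Airflow's real DagRun; the run-type literals are placeholders). The multi-dag shape needs a GROUP BY subquery joined back to dag_run, while the single-dag shape collapses to a scalar MAX() that a composite index on (dag_id, execution_date) can answer with a single range scan.

from sqlalchemy import Column, DateTime, Integer, String, func, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class DagRun(Base):
    """Simplified stand-in for airflow.models.DagRun (illustration only)."""

    __tablename__ = "dag_run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String, index=True)
    run_type = Column(String)
    execution_date = Column(DateTime, index=True)


AUTOMATED = ("backfill", "scheduled")  # placeholder run-type values

# Multi-dag shape: per-dag MAX via GROUP BY, joined back to dag_run.
per_dag_max = (
    select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
    .where(DagRun.dag_id.in_(["dag1", "dag2"]), DagRun.run_type.in_(AUTOMATED))
    .group_by(DagRun.dag_id)
    .subquery()
)
multi = select(DagRun).where(
    DagRun.dag_id == per_dag_max.c.dag_id,
    DagRun.execution_date == per_dag_max.c.max_execution_date,
)

# Single-dag shape: scalar subquery, no GROUP BY, index-friendly.
single_max = (
    select(func.max(DagRun.execution_date))
    .where(DagRun.dag_id == "dag1", DagRun.run_type.in_(AUTOMATED))
    .scalar_subquery()
)
single = select(DagRun).where(DagRun.dag_id == "dag1", DagRun.execution_date == single_max)

# str() compiles to generic SQL, mirroring what the new tests assert.
print(multi)
print(single)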
69 changes: 50 additions & 19 deletions airflow/models/dag.py
@@ -73,7 +73,7 @@
     update,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
-from sqlalchemy.orm import backref, joinedload, relationship
+from sqlalchemy.orm import backref, joinedload, load_only, relationship
 from sqlalchemy.sql import Select, expression
 
 import airflow.templates
@@ -3062,27 +3062,13 @@ def bulk_write_to_db(
                 session.add(orm_dag)
                 orm_dags.append(orm_dag)
 
-        dag_id_to_last_automated_run: dict[str, DagRun] = {}
+        latest_runs: dict[str, DagRun] = {}
         num_active_runs: dict[str, int] = {}
         # Skip these queries entirely if no DAGs can be scheduled to save time.
         if any(dag.timetable.can_be_scheduled for dag in dags):
             # Get the latest automated dag run for each existing dag as a single query (avoid n+1 query)
-            last_automated_runs_subq = (
-                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
-                .where(
-                    DagRun.dag_id.in_(existing_dags),
-                    or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
-                )
-                .group_by(DagRun.dag_id)
-                .subquery()
-            )
-            last_automated_runs = session.scalars(
-                select(DagRun).where(
-                    DagRun.dag_id == last_automated_runs_subq.c.dag_id,
-                    DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
-                )
-            )
-            dag_id_to_last_automated_run = {run.dag_id: run for run in last_automated_runs}
+            query = cls._get_latest_runs_query(existing_dags, session)
+            latest_runs = {run.dag_id: run for run in session.scalars(query)}
 
         # Get number of active dagruns for all dags we are processing as a single query.
         num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
@@ -3116,7 +3102,7 @@ def bulk_write_to_db(
             orm_dag.timetable_description = dag.timetable.description
             orm_dag.processor_subdir = processor_subdir
 
-            last_automated_run: DagRun | None = dag_id_to_last_automated_run.get(dag.dag_id)
+            last_automated_run: DagRun | None = latest_runs.get(dag.dag_id)
             if last_automated_run is None:
                 last_automated_data_interval = None
             else:
@@ -3253,6 +3239,51 @@ def bulk_write_to_db(
         for dag in dags:
             cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
 
+    @classmethod
+    def _get_latest_runs_query(cls, dags, session) -> Query:
+        """
+        Query the database to retrieve the last automated run for each dag.
+
+        :param dags: dags to query
+        :param session: sqlalchemy session object
+        """
+        if len(dags) == 1:
+            # Index optimized fast path to avoid more complicated & slower groupby queryplan
+            existing_dag_id = list(dags)[0].dag_id
+            last_automated_runs_subq = (
+                select(func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id == existing_dag_id,
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == existing_dag_id, DagRun.execution_date == last_automated_runs_subq
+            )
+        else:
+            last_automated_runs_subq = (
+                select(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
+                .where(
+                    DagRun.dag_id.in_(dags),
+                    DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
+                )
+                .group_by(DagRun.dag_id)
+                .subquery()
+            )
+            query = select(DagRun).where(
+                DagRun.dag_id == last_automated_runs_subq.c.dag_id,
+                DagRun.execution_date == last_automated_runs_subq.c.max_execution_date,
+            )
+        return query.options(
+            load_only(
+                DagRun.dag_id,
+                DagRun.execution_date,
+                DagRun.data_interval_start,
+                DagRun.data_interval_end,
+            )
+        )
+
     @provide_session
     def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
         """
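
A note on the load_only(...) option that _get_latest_runs_query attaches: it tells the ORM to fetch only the listed columns (plus the primary key), which keeps the row payload small since bulk_write_to_db only reads the data-interval fields. Here is a minimal, self-contained sketch of the same pattern, again with a hypothetical stand-in model rather than Airflow's:

from sqlalchemy import Column, DateTime, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base, load_only

Base = declarative_base()


class DagRun(Base):
    """Hypothetical stand-in model, illustration only."""

    __tablename__ = "dag_run"
    id = Column(Integer, primary_key=True)
    dag_id = Column(String)
    execution_date = Column(DateTime)
    data_interval_start = Column(DateTime)
    data_interval_end = Column(DateTime)
    conf = Column(String)  # a wide column we do not want to fetch


engine = create_engine("sqlite://", echo=True)
Base.metadata.create_all(engine)

with Session(engine) as session:
    query = select(DagRun).options(
        load_only(
            DagRun.dag_id,
            DagRun.execution_date,
            DagRun.data_interval_start,
            DagRun.data_interval_end,
        )
    )
    # With echo=True, the emitted SELECT lists only the requested columns
    # (and the primary key); "conf" stays deferred until first access.
    session.scalars(query).all()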
86 changes: 86 additions & 0 deletions tests/models/test_dag.py
@@ -952,6 +952,59 @@ def test_bulk_write_to_db(self):
             for row in session.query(DagModel.last_parsed_time).all():
                 assert row[0] is not None
 
+    def test_bulk_write_to_db_single_dag(self):
+        """
+        Test bulk_write_to_db for a single dag using the index optimized query
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(1)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()}
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-sync should do fewer queries
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
+    def test_bulk_write_to_db_multiple_dags(self):
+        """
+        Test bulk_write_to_db for multiple dags which does not use the index optimized query
+        """
+        clear_db_dags()
+        dags = [DAG(f"dag-bulk-sync-{i}", start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(4)]
+
+        with assert_queries_count(5):
+            DAG.bulk_write_to_db(dags)
+        with create_session() as session:
+            assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
+                row[0] for row in session.query(DagModel.dag_id).all()
+            }
+            assert {
+                ("dag-bulk-sync-0", "test-dag"),
+                ("dag-bulk-sync-1", "test-dag"),
+                ("dag-bulk-sync-2", "test-dag"),
+                ("dag-bulk-sync-3", "test-dag"),
+            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+
+            for row in session.query(DagModel.last_parsed_time).all():
+                assert row[0] is not None
+
+        # Re-sync should do fewer queries
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+        with assert_queries_count(8):
+            DAG.bulk_write_to_db(dags)
+
     @pytest.mark.parametrize("interval", [None, "@daily"])
     def test_bulk_write_to_db_interval_save_runtime(self, interval):
         mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags)
@pytest.mark.parametrize("interval", [None, "@daily"])
def test_bulk_write_to_db_interval_save_runtime(self, interval):
mock_active_runs_of_dags = mock.MagicMock(side_effect=DagRun.active_runs_of_dags)
@@ -4082,3 +4135,36 @@ def test_validate_setup_teardown_trigger_rule(self):
             Exception, match="Setup tasks must be followed with trigger rule ALL_SUCCESS."
         ):
             dag.validate_setup_teardown()
+
+
+def test_get_latest_runs_query_one_dag(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_1 AND dag_run.execution_date = (SELECT max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id = :dag_id_2 AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]))",
+    ]
+    assert actual == expected
+
+
+def test_get_latest_runs_query_two_dags(dag_maker, session):
+    with dag_maker(dag_id="dag1") as dag1:
+        ...
+    with dag_maker(dag_id="dag2") as dag2:
+        ...
+    query = DAG._get_latest_runs_query(dags=[dag1, dag2], session=session)
+    actual = [x.strip() for x in str(query.compile()).splitlines()]
+    print("\n".join(actual))
+    expected = [
+        "SELECT dag_run.id, dag_run.dag_id, dag_run.execution_date, dag_run.data_interval_start, dag_run.data_interval_end",
+        "FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.execution_date) AS max_execution_date",
+        "FROM dag_run",
+        "WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
+        "WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.execution_date = anon_1.max_execution_date",
+    ]
+    assert actual == expected
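For anyone who wants to confirm the index claim locally, here is a small, self-contained sketch, assuming SQLite and a hand-rolled table with a composite (dag_id, execution_date) index comparable to the one on Airflow's dag_run table; Airflow's real schema and index names differ. EXPLAIN QUERY PLAN on the single-dag shape should report an index search rather than a full scan plus temporary grouping:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE dag_run (
        id INTEGER PRIMARY KEY,
        dag_id TEXT,
        run_type TEXT,
        execution_date TIMESTAMP
    );
    CREATE INDEX idx_dag_id_execution_date ON dag_run (dag_id, execution_date);
    """
)

single_dag_sql = """
    SELECT * FROM dag_run
    WHERE dag_id = 'dag1'
      AND execution_date = (
          SELECT MAX(execution_date) FROM dag_run
          WHERE dag_id = 'dag1'
            AND run_type IN ('backfill', 'scheduled')
      )
"""
# Each row is (id, parent, notused, detail); the detail column typically
# shows "SEARCH dag_run USING INDEX idx_dag_id_execution_date (dag_id=?)".
for row in conn.execute("EXPLAIN QUERY PLAN " + single_dag_sql):
    print(row)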