Skip to content

Commit

Permalink
Consolidate asset orphanization and activation (apache#43254)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Oct 26, 2024
1 parent a700382 commit 3ee844b
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 148 deletions.
40 changes: 14 additions & 26 deletions airflow/dag_processing/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from airflow.assets import Asset, AssetAlias
from airflow.assets.manager import asset_manager
from airflow.models.asset import (
AssetActive,
AssetAliasModel,
AssetModel,
DagScheduleAssetAliasReference,
Expand Down Expand Up @@ -277,7 +276,7 @@ class AssetModelOperation(NamedTuple):
schedule_asset_references: dict[str, list[Asset]]
schedule_asset_alias_references: dict[str, list[AssetAlias]]
outlet_references: dict[str, list[tuple[str, Asset]]]
assets: dict[str, Asset]
assets: dict[tuple[str, str], Asset]
asset_aliases: dict[str, AssetAlias]

@classmethod
Expand All @@ -300,22 +299,25 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
]
for dag_id, dag in dags.items()
},
assets={asset.uri: asset for asset in _find_all_assets(dags.values())},
assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
)
return coll

def add_assets(self, *, session: Session) -> dict[str, AssetModel]:
def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
# Optimization: skip all database calls if no assets were collected.
if not self.assets:
return {}
orm_assets: dict[str, AssetModel] = {
am.uri: am for am in session.scalars(select(AssetModel).where(AssetModel.uri.in_(self.assets)))
orm_assets: dict[tuple[str, str], AssetModel] = {
(am.name, am.uri): am
for am in session.scalars(
select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets))
)
}
orm_assets.update(
(model.uri, model)
((model.name, model.uri), model)
for model in asset_manager.create_assets(
[asset for uri, asset in self.assets.items() if uri not in orm_assets],
[asset for name_uri, asset in self.assets.items() if name_uri not in orm_assets],
session=session,
)
)
Expand All @@ -340,24 +342,10 @@ def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]:
)
return orm_aliases

def add_asset_active_references(self, assets: Collection[AssetModel], *, session: Session) -> None:
existing_entries = set(
session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_(AssetActive.name, AssetActive.uri).in_((asset.name, asset.uri) for asset in assets)
)
)
)
session.add_all(
AssetActive.for_asset(asset)
for asset in assets
if (asset.name, asset.uri) not in existing_entries
)

def add_dag_asset_references(
self,
dags: dict[str, DagModel],
assets: dict[str, AssetModel],
assets: dict[tuple[str, str], AssetModel],
*,
session: Session,
) -> None:
Expand All @@ -369,7 +357,7 @@ def add_dag_asset_references(
if not references:
dags[dag_id].schedule_asset_references = []
continue
referenced_asset_ids = {asset.id for asset in (assets[r.uri] for r in references)}
referenced_asset_ids = {asset.id for asset in (assets[r.name, r.uri] for r in references)}
orm_refs = {r.asset_id: r for r in dags[dag_id].schedule_asset_references}
for asset_id, ref in orm_refs.items():
if asset_id not in referenced_asset_ids:
Expand Down Expand Up @@ -409,7 +397,7 @@ def add_dag_asset_alias_references(
def add_task_asset_references(
self,
dags: dict[str, DagModel],
assets: dict[str, AssetModel],
assets: dict[tuple[str, str], AssetModel],
*,
session: Session,
) -> None:
Expand All @@ -423,7 +411,7 @@ def add_task_asset_references(
continue
referenced_outlets = {
(task_id, asset.id)
for task_id, asset in ((task_id, assets[d.uri]) for task_id, d in references)
for task_id, asset in ((task_id, assets[d.name, d.uri]) for task_id, d in references)
}
orm_refs = {(r.task_id, r.asset_id): r for r in dags[dag_id].task_outlet_asset_references}
for key, ref in orm_refs.items():
Expand Down
11 changes: 9 additions & 2 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from typing import TYPE_CHECKING, Generator, Iterable

from setproctitle import setproctitle
from sqlalchemy import delete, event
from sqlalchemy import delete, event, select

from airflow import settings
from airflow.api_internal.internal_api_call import internal_api_call
Expand Down Expand Up @@ -533,7 +533,14 @@ def _validate_task_pools_and_update_dag_warnings(
)
)

stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
stored_warnings = set(
session.scalars(
select(DagWarning).where(
DagWarning.dag_id.in_(dag_ids),
DagWarning.warning_type == DagWarningType.NONEXISTENT_POOL,
)
)
)

for warning_to_delete in stored_warnings - warnings:
session.delete(warning_to_delete)
Expand Down
118 changes: 91 additions & 27 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import itertools
import multiprocessing
import operator
import os
import signal
import sys
Expand Down Expand Up @@ -55,6 +56,7 @@
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.stats import Stats
Expand Down Expand Up @@ -1078,7 +1080,7 @@ def _run_scheduler_loop(self) -> None:

timers.call_regular_interval(
conf.getfloat("scheduler", "parsing_cleanup_interval"),
self._orphan_unreferenced_assets,
self._update_asset_orphanage,
)

if self._standalone_dag_processor:
Expand Down Expand Up @@ -2068,44 +2070,106 @@ def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None:
SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
session.flush()

def _get_orphaning_identifier(self, asset: AssetModel) -> tuple[str, str]:
self.log.info("Orphaning unreferenced %s", asset)
return asset.name, asset.uri

@provide_session
def _orphan_unreferenced_assets(self, session: Session = NEW_SESSION) -> None:
def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
"""
Detect orphaned assets and remove their active entry.
Check asset orphanization and update their active entries.
An orphaned asset is no longer referenced in any DAG schedule parameters or task outlets.
An orphaned asset is no longer referenced in any DAG schedule parameters
or task outlets. Active assets (non-orphaned) have entries in AssetActive
and must have unique names and URIs.
"""
orphaned_asset_query = session.scalars(
select(AssetModel)
.join(
DagScheduleAssetReference,
isouter=True,
)
.join(
TaskOutletAssetReference,
isouter=True,
)
# Group assets into orphaned=True and orphaned=False groups.
orphaned = (
(func.count(DagScheduleAssetReference.dag_id) + func.count(TaskOutletAssetReference.dag_id)) == 0
).label("orphaned")
asset_reference_query = session.execute(
select(orphaned, AssetModel)
.outerjoin(DagScheduleAssetReference)
.outerjoin(TaskOutletAssetReference)
.group_by(AssetModel.id)
.where(AssetModel.active.has())
.having(
and_(
func.count(DagScheduleAssetReference.dag_id) == 0,
func.count(TaskOutletAssetReference.dag_id) == 0,
.order_by(orphaned)
)
asset_orphanation: dict[bool, Collection[AssetModel]] = {
orphaned: [asset for _, asset in group]
for orphaned, group in itertools.groupby(asset_reference_query, key=operator.itemgetter(0))
}
self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), session=session)
self._activate_referenced_assets(asset_orphanation.get(False, ()), session=session)

@staticmethod
def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
if assets:
session.execute(
delete(AssetActive).where(
tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
)
)
Stats.gauge("asset.orphaned", len(assets))

@staticmethod
def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
if not assets:
return

active_assets = set(
session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
)
)
)

orphaning_identifiers = [self._get_orphaning_identifier(asset) for asset in orphaned_asset_query]
active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}

def _generate_dag_warnings(offending: AssetModel, attr: str, value: str) -> Iterator[DagWarning]:
for ref in itertools.chain(offending.consuming_dags, offending.producing_tasks):
yield DagWarning(
dag_id=ref.dag_id,
error_type=DagWarningType.ASSET_CONFLICT,
message=f"Cannot activate asset {offending}; {attr} is already associated to {value!r}",
)

def _activate_assets_generate_warnings() -> Iterator[DagWarning]:
incoming_name_to_uri: dict[str, str] = {}
incoming_uri_to_name: dict[str, str] = {}
for asset in assets:
if (asset.name, asset.uri) in active_assets:
continue
existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
if existing_uri is not None and existing_uri != asset.uri:
yield from _generate_dag_warnings(asset, "name", existing_uri)
continue
existing_name = active_uri_to_name.get(asset.uri) or incoming_uri_to_name.get(asset.uri)
if existing_name is not None and existing_name != asset.name:
yield from _generate_dag_warnings(asset, "uri", existing_name)
continue
incoming_name_to_uri[asset.name] = asset.uri
incoming_uri_to_name[asset.uri] = asset.name
session.add(AssetActive.for_asset(asset))

warnings_to_have = {w.dag_id: w for w in _activate_assets_generate_warnings()}
session.execute(
delete(AssetActive).where(
tuple_in_condition((AssetActive.name, AssetActive.uri), orphaning_identifiers)
delete(DagWarning).where(
DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
DagWarning.dag_id.not_in(warnings_to_have),
)
)
existing_warned_dag_ids: set[str] = set(
session.scalars(
select(DagWarning.dag_id).where(
DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
DagWarning.dag_id.not_in(warnings_to_have),
)
)
)
Stats.gauge("asset.orphaned", len(orphaning_identifiers))
for dag_id, warning in warnings_to_have.items():
if dag_id in existing_warned_dag_ids:
session.merge(warning)
continue
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)

def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
"""Organize TIs into lists per their respective executor."""
Expand Down
6 changes: 4 additions & 2 deletions airflow/models/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ class AssetModel(Base):
created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)

active = relationship("AssetActive", uselist=False, viewonly=True)
active = relationship("AssetActive", uselist=False, viewonly=True, back_populates="asset")

consuming_dags = relationship("DagScheduleAssetReference", back_populates="asset")
producing_tasks = relationship("TaskOutletAssetReference", back_populates="asset")
Expand Down Expand Up @@ -221,7 +221,7 @@ def __hash__(self):
return hash((self.name, self.uri))

def __repr__(self):
return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
return f"{self.__class__.__name__}(name={self.name!r}, uri={self.uri!r}, extra={self.extra!r})"

def to_public(self) -> Asset:
return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra)
Expand Down Expand Up @@ -264,6 +264,8 @@ class AssetActive(Base):
nullable=False,
)

asset = relationship("AssetModel", back_populates="active")

__tablename__ = "asset_active"
__table_args__ = (
PrimaryKeyConstraint(name, uri, name="asset_active_pkey"),
Expand Down
1 change: 0 additions & 1 deletion airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,7 +2571,6 @@ def bulk_write_to_db(
orm_asset_aliases = asset_op.add_asset_aliases(session=session)
session.flush() # This populates id so we can create fks in later calls.

asset_op.add_asset_active_references(orm_assets.values(), session=session)
asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session)
asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session)
asset_op.add_task_asset_references(orm_dags, orm_assets, session=session)
Expand Down
1 change: 1 addition & 0 deletions airflow/models/dagwarning.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,5 @@ class DagWarningType(str, Enum):
in the DagWarning model.
"""

ASSET_CONFLICT = "asset conflict"
NONEXISTENT_POOL = "non-existent pool"
Loading

0 comments on commit 3ee844b

Please sign in to comment.