diff --git a/airflow/dag_processing/collection.py b/airflow/dag_processing/collection.py
index 034c9c0540125..f27f45dda82e3 100644
--- a/airflow/dag_processing/collection.py
+++ b/airflow/dag_processing/collection.py
@@ -37,7 +37,6 @@
 from airflow.assets import Asset, AssetAlias
 from airflow.assets.manager import asset_manager
 from airflow.models.asset import (
-    AssetActive,
     AssetAliasModel,
     AssetModel,
     DagScheduleAssetReference,
@@ -277,7 +276,7 @@ class AssetModelOperation(NamedTuple):
     schedule_asset_references: dict[str, list[Asset]]
     schedule_asset_alias_references: dict[str, list[AssetAlias]]
     outlet_references: dict[str, list[tuple[str, Asset]]]
-    assets: dict[str, Asset]
+    assets: dict[tuple[str, str], Asset]
     asset_aliases: dict[str, AssetAlias]

     @classmethod
@@ -300,22 +299,25 @@ def collect(cls, dags: dict[str, DAG]) -> Self:
                 ]
                 for dag_id, dag in dags.items()
             },
-            assets={asset.uri: asset for asset in _find_all_assets(dags.values())},
+            assets={(asset.name, asset.uri): asset for asset in _find_all_assets(dags.values())},
             asset_aliases={alias.name: alias for alias in _find_all_asset_aliases(dags.values())},
         )
         return coll

-    def add_assets(self, *, session: Session) -> dict[str, AssetModel]:
+    def add_assets(self, *, session: Session) -> dict[tuple[str, str], AssetModel]:
         # Optimization: skip all database calls if no assets were collected.
         if not self.assets:
             return {}
-        orm_assets: dict[str, AssetModel] = {
-            am.uri: am for am in session.scalars(select(AssetModel).where(AssetModel.uri.in_(self.assets)))
+        orm_assets: dict[tuple[str, str], AssetModel] = {
+            (am.name, am.uri): am
+            for am in session.scalars(
+                select(AssetModel).where(tuple_(AssetModel.name, AssetModel.uri).in_(self.assets))
+            )
         }
         orm_assets.update(
-            (model.uri, model)
+            ((model.name, model.uri), model)
             for model in asset_manager.create_assets(
-                [asset for uri, asset in self.assets.items() if uri not in orm_assets],
+                [asset for name_uri, asset in self.assets.items() if name_uri not in orm_assets],
                 session=session,
             )
         )
@@ -340,24 +342,10 @@ def add_asset_aliases(self, *, session: Session) -> dict[str, AssetAliasModel]:
         )
         return orm_aliases

-    def add_asset_active_references(self, assets: Collection[AssetModel], *, session: Session) -> None:
-        existing_entries = set(
-            session.execute(
-                select(AssetActive.name, AssetActive.uri).where(
-                    tuple_(AssetActive.name, AssetActive.uri).in_((asset.name, asset.uri) for asset in assets)
-                )
-            )
-        )
-        session.add_all(
-            AssetActive.for_asset(asset)
-            for asset in assets
-            if (asset.name, asset.uri) not in existing_entries
-        )
-
     def add_dag_asset_references(
         self,
         dags: dict[str, DagModel],
-        assets: dict[str, AssetModel],
+        assets: dict[tuple[str, str], AssetModel],
         *,
         session: Session,
     ) -> None:
@@ -369,7 +357,7 @@ def add_dag_asset_references(
             if not references:
                 dags[dag_id].schedule_asset_references = []
                 continue
-            referenced_asset_ids = {asset.id for asset in (assets[r.uri] for r in references)}
+            referenced_asset_ids = {asset.id for asset in (assets[r.name, r.uri] for r in references)}
             orm_refs = {r.asset_id: r for r in dags[dag_id].schedule_asset_references}
             for asset_id, ref in orm_refs.items():
                 if asset_id not in referenced_asset_ids:
@@ -409,7 +397,7 @@ def add_dag_asset_alias_references(
     def add_task_asset_references(
         self,
         dags: dict[str, DagModel],
-        assets: dict[str, AssetModel],
+        assets: dict[tuple[str, str], AssetModel],
         *,
         session: Session,
     ) -> None:
@@ -423,7 +411,7 @@ def add_task_asset_references(
                 continue
             referenced_outlets = {
                 (task_id, asset.id)
-                for task_id, asset in ((task_id, assets[d.uri]) for task_id, d in references)
+                for task_id, asset in ((task_id, assets[d.name, d.uri]) for task_id, d in references)
             }
             orm_refs = {(r.task_id, r.asset_id): r for r in dags[dag_id].task_outlet_asset_references}
             for key, ref in orm_refs.items():
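
Context for the collection.py change: assets are now keyed by the (name, uri) pair instead of the URI alone, so the batched lookup in `add_assets` compares row tuples. A minimal, self-contained sketch of that pattern (illustrative table and values, not the Airflow models; needs a backend with row-value support such as recent SQLite or PostgreSQL):

```python
from sqlalchemy import create_engine, select, tuple_
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Thing(Base):  # stand-in for AssetModel; the real columns live in airflow/models/asset.py
    __tablename__ = "thing"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]
    uri: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

wanted = {("a", "s3://a"), ("b", "s3://b")}  # (name, uri) keys, as in AssetModelOperation.assets
with Session(engine) as session:
    session.add(Thing(name="a", uri="s3://a"))
    session.flush()
    # One query fetches every already-known row whose (name, uri) tuple is in the wanted set.
    found = {
        (t.name, t.uri): t
        for t in session.scalars(select(Thing).where(tuple_(Thing.name, Thing.uri).in_(list(wanted))))
    }
    missing = [key for key in wanted if key not in found]  # these would go through asset_manager.create_assets
    assert list(found) == [("a", "s3://a")] and missing == [("b", "s3://b")]
```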
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index f030cb75019e5..8694f5890ccd8 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -28,7 +28,7 @@
 from typing import TYPE_CHECKING, Generator, Iterable

 from setproctitle import setproctitle
-from sqlalchemy import delete, event
+from sqlalchemy import delete, event, select

 from airflow import settings
 from airflow.api_internal.internal_api_call import internal_api_call
@@ -533,7 +533,14 @@ def _validate_task_pools_and_update_dag_warnings(
                 )
             )

-        stored_warnings = set(session.query(DagWarning).filter(DagWarning.dag_id.in_(dag_ids)).all())
+        stored_warnings = set(
+            session.scalars(
+                select(DagWarning).where(
+                    DagWarning.dag_id.in_(dag_ids),
+                    DagWarning.warning_type == DagWarningType.NONEXISTENT_POOL,
+                )
+            )
+        )

         for warning_to_delete in stored_warnings - warnings:
             session.delete(warning_to_delete)
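
The processor.py change scopes the DAG processor's warning reconciliation to `NONEXISTENT_POOL`, so the `ASSET_CONFLICT` warnings written by the scheduler below are not treated as stale and deleted. A plain-data sketch of that reconcile shape (illustrative values, no ORM):

```python
# Warnings are modelled here as (dag_id, warning_type) tuples; the real code works on DagWarning rows.
current = {("dag_a", "non-existent pool"), ("dag_b", "non-existent pool")}  # computed during this parse
stored = {("dag_b", "non-existent pool"), ("dag_c", "non-existent pool")}   # pool warnings already stored
assert stored - current == {("dag_c", "non-existent pool")}  # stale, gets deleted
assert current - stored == {("dag_a", "non-existent pool")}  # new, gets merged/added
# Because `stored` is filtered to NONEXISTENT_POOL, an ("dag_c", "asset conflict") row is never touched here.
```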
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 04ea8c5e61675..15042b0d3f177 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -19,6 +19,7 @@

 import itertools
 import multiprocessing
+import operator
 import os
 import signal
 import sys
@@ -55,6 +56,7 @@
 from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
+from airflow.models.dagwarning import DagWarning, DagWarningType
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
 from airflow.stats import Stats
@@ -1078,7 +1080,7 @@ def _run_scheduler_loop(self) -> None:

         timers.call_regular_interval(
             conf.getfloat("scheduler", "parsing_cleanup_interval"),
-            self._orphan_unreferenced_assets,
+            self._update_asset_orphanage,
         )

         if self._standalone_dag_processor:
@@ -2068,44 +2070,106 @@ def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None:
                 SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
             session.flush()

-    def _get_orphaning_identifier(self, asset: AssetModel) -> tuple[str, str]:
-        self.log.info("Orphaning unreferenced %s", asset)
-        return asset.name, asset.uri
-
     @provide_session
-    def _orphan_unreferenced_assets(self, session: Session = NEW_SESSION) -> None:
+    def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
         """
-        Detect orphaned assets and remove their active entry.
+        Check which assets are orphaned and update their active entries accordingly.

-        An orphaned asset is no longer referenced in any DAG schedule parameters or task outlets.
+        An orphaned asset is no longer referenced in any DAG schedule parameters
+        or task outlets. Active assets (non-orphaned) have entries in AssetActive
+        and must have unique names and URIs.
         """
-        orphaned_asset_query = session.scalars(
-            select(AssetModel)
-            .join(
-                DagScheduleAssetReference,
-                isouter=True,
-            )
-            .join(
-                TaskOutletAssetReference,
-                isouter=True,
-            )
+        # Group assets into orphaned=True and orphaned=False groups.
+        orphaned = (
+            (func.count(DagScheduleAssetReference.dag_id) + func.count(TaskOutletAssetReference.dag_id)) == 0
+        ).label("orphaned")
+        asset_reference_query = session.execute(
+            select(orphaned, AssetModel)
+            .outerjoin(DagScheduleAssetReference)
+            .outerjoin(TaskOutletAssetReference)
             .group_by(AssetModel.id)
-            .where(AssetModel.active.has())
-            .having(
-                and_(
-                    func.count(DagScheduleAssetReference.dag_id) == 0,
-                    func.count(TaskOutletAssetReference.dag_id) == 0,
+            .order_by(orphaned)
+        )
+        asset_orphanation: dict[bool, Collection[AssetModel]] = {
+            orphaned: [asset for _, asset in group]
+            for orphaned, group in itertools.groupby(asset_reference_query, key=operator.itemgetter(0))
+        }
+        self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), session=session)
+        self._activate_referenced_assets(asset_orphanation.get(False, ()), session=session)
+
+    @staticmethod
+    def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
+        if assets:
+            session.execute(
+                delete(AssetActive).where(
+                    tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
+                )
+            )
+        Stats.gauge("asset.orphaned", len(assets))
+
+    @staticmethod
+    def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
+        if not assets:
+            return
+
+        active_assets = set(
+            session.execute(
+                select(AssetActive.name, AssetActive.uri).where(
+                    tuple_in_condition((AssetActive.name, AssetActive.uri), ((a.name, a.uri) for a in assets))
                 )
             )
         )
-        orphaning_identifiers = [self._get_orphaning_identifier(asset) for asset in orphaned_asset_query]
+        active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
+        active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}
+
+        def _generate_dag_warnings(offending: AssetModel, attr: str, value: str) -> Iterator[DagWarning]:
+            for ref in itertools.chain(offending.consuming_dags, offending.producing_tasks):
+                yield DagWarning(
+                    dag_id=ref.dag_id,
+                    error_type=DagWarningType.ASSET_CONFLICT,
+                    message=f"Cannot activate asset {offending}; {attr} is already associated to {value!r}",
+                )
+
+        def _activate_assets_generate_warnings() -> Iterator[DagWarning]:
+            incoming_name_to_uri: dict[str, str] = {}
+            incoming_uri_to_name: dict[str, str] = {}
+            for asset in assets:
+                if (asset.name, asset.uri) in active_assets:
+                    continue
+                existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
+                if existing_uri is not None and existing_uri != asset.uri:
+                    yield from _generate_dag_warnings(asset, "name", existing_uri)
+                    continue
+                existing_name = active_uri_to_name.get(asset.uri) or incoming_uri_to_name.get(asset.uri)
+                if existing_name is not None and existing_name != asset.name:
+                    yield from _generate_dag_warnings(asset, "uri", existing_name)
+                    continue
+                incoming_name_to_uri[asset.name] = asset.uri
+                incoming_uri_to_name[asset.uri] = asset.name
+                session.add(AssetActive.for_asset(asset))
+
+        warnings_to_have = {w.dag_id: w for w in _activate_assets_generate_warnings()}
         session.execute(
-            delete(AssetActive).where(
-                tuple_in_condition((AssetActive.name, AssetActive.uri), orphaning_identifiers)
+            delete(DagWarning).where(
+                DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+                DagWarning.dag_id.not_in(warnings_to_have),
+            )
+        )
+        existing_warned_dag_ids: set[str] = set(
+            session.scalars(
+                select(DagWarning.dag_id).where(
+                    DagWarning.warning_type == DagWarningType.ASSET_CONFLICT,
+                    DagWarning.dag_id.in_(warnings_to_have),
+                )
             )
         )
-        Stats.gauge("asset.orphaned", len(orphaning_identifiers))
+        for dag_id, warning in warnings_to_have.items():
+            if dag_id in existing_warned_dag_ids:
+                session.merge(warning)
+                continue
+            session.add(warning)
+            existing_warned_dag_ids.add(warning.dag_id)

     def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
         """Organize TIs into lists per their respective executor."""
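
The activation logic above is dense; the uniqueness rule it enforces can be sketched with plain tuples as follows (illustrative only, no ORM or DagWarning handling):

```python
# Given candidate (name, uri) pairs and the already-active pairs, decide which candidates
# can be activated and which conflict on name or on uri, including conflicts within the batch.
def split_activation(candidates, active):
    name_to_uri = {name: uri for name, uri in active}
    uri_to_name = {uri: name for name, uri in active}
    to_activate, conflicts = [], []
    for name, uri in candidates:
        if (name, uri) in active:
            continue  # already active, nothing to do
        if name_to_uri.get(name, uri) != uri:
            conflicts.append((name, uri, "name"))  # this name is already bound to another URI
            continue
        if uri_to_name.get(uri, name) != name:
            conflicts.append((name, uri, "uri"))  # this URI is already bound to another name
            continue
        name_to_uri[name], uri_to_name[uri] = uri, name  # also guards later candidates in the same batch
        to_activate.append((name, uri))
    return to_activate, conflicts


print(split_activation([("a", "s3://x"), ("b", "s3://x")], active={("c", "s3://y")}))
# -> ([('a', 's3://x')], [('b', 's3://x', 'uri')])
```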
- Stats.gauge("asset.orphaned", len(orphaning_identifiers)) + for dag_id, warning in warnings_to_have.items(): + if dag_id in existing_warned_dag_ids: + session.merge(warning) + continue + session.add(warning) + existing_warned_dag_ids.add(warning.dag_id) def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]: """Organize TIs into lists per their respective executor.""" diff --git a/airflow/models/asset.py b/airflow/models/asset.py index fc77cb7a31d6f..d6092aaff1b7c 100644 --- a/airflow/models/asset.py +++ b/airflow/models/asset.py @@ -181,7 +181,7 @@ class AssetModel(Base): created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False) updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False) - active = relationship("AssetActive", uselist=False, viewonly=True) + active = relationship("AssetActive", uselist=False, viewonly=True, back_populates="asset") consuming_dags = relationship("DagScheduleAssetReference", back_populates="asset") producing_tasks = relationship("TaskOutletAssetReference", back_populates="asset") @@ -221,7 +221,7 @@ def __hash__(self): return hash((self.name, self.uri)) def __repr__(self): - return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})" + return f"{self.__class__.__name__}(name={self.name!r}, uri={self.uri!r}, extra={self.extra!r})" def to_public(self) -> Asset: return Asset(name=self.name, uri=self.uri, group=self.group, extra=self.extra) @@ -264,6 +264,8 @@ class AssetActive(Base): nullable=False, ) + asset = relationship("AssetModel", back_populates="active") + __tablename__ = "asset_active" __table_args__ = ( PrimaryKeyConstraint(name, uri, name="asset_active_pkey"), diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 00943ec2ee262..fd1c67debe248 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2571,7 +2571,6 @@ def bulk_write_to_db( orm_asset_aliases = asset_op.add_asset_aliases(session=session) session.flush() # This populates id so we can create fks in later calls. - asset_op.add_asset_active_references(orm_assets.values(), session=session) asset_op.add_dag_asset_references(orm_dags, orm_assets, session=session) asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases, session=session) asset_op.add_task_asset_references(orm_dags, orm_assets, session=session) diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py index ffab515f85495..e0c271c4c8ec2 100644 --- a/airflow/models/dagwarning.py +++ b/airflow/models/dagwarning.py @@ -104,4 +104,5 @@ class DagWarningType(str, Enum): in the DagWarning model. 
""" + ASSET_CONFLICT = "asset conflict" NONEXISTENT_POOL = "non-existent pool" diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index ac48344435a3b..3d71d5987994d 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -52,7 +52,7 @@ from airflow.jobs.job import Job, run_job from airflow.jobs.local_task_job_runner import LocalTaskJobRunner from airflow.jobs.scheduler_job_runner import SchedulerJobRunner -from airflow.models.asset import AssetDagRunQueue, AssetEvent, AssetModel +from airflow.models.asset import AssetActive, AssetDagRunQueue, AssetEvent, AssetModel from airflow.models.backfill import Backfill, _create_backfill from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag @@ -6160,84 +6160,102 @@ def test_update_dagrun_state_for_paused_dag_not_for_backfill(self, dag_maker, se (backfill_run,) = DagRun.find(dag_id=dag.dag_id, run_type=DagRunType.BACKFILL_JOB, session=session) assert backfill_run.state == State.SUCCESS + @staticmethod + def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel]]: + assets = session.execute( + select(AssetModel, AssetActive) + .outerjoin( + AssetActive, + (AssetModel.name == AssetActive.name) & (AssetModel.uri == AssetActive.uri), + ) + .order_by(AssetModel.uri) + ).all() + return [a for a, v in assets if not v], [a for a, v in assets if v] + + @pytest.mark.want_activate_assets(False) def test_asset_orphaning(self, dag_maker, session): + self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull) + asset1 = Asset(uri="ds1") asset2 = Asset(uri="ds2") asset3 = Asset(uri="ds3") asset4 = Asset(uri="ds4") + asset5 = Asset(uri="ds5") with dag_maker(dag_id="assets-1", schedule=[asset1, asset2], session=session): BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset4]) - non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count() - assert non_orphaned_asset_count == 4 - orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count() - assert orphaned_asset_count == 0 + # Assets not activated yet; asset5 is not even registered (since it's not used anywhere). + orphaned, active = self._find_assets_activation(session) + assert active == [] + assert orphaned == [asset1, asset2, asset3, asset4] - # now remove 2 asset references + self.job_runner._update_asset_orphanage(session=session) + session.flush() + + # Assets are activated after scheduler loop. + orphaned, active = self._find_assets_activation(session) + assert active == [asset1, asset2, asset3, asset4] + assert orphaned == [] + + # Now remove 2 asset references and add asset5. with dag_maker(dag_id="assets-1", schedule=[asset1], session=session): - BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3]) + BashOperator(task_id="task", bash_command="echo 1", outlets=[asset3, asset5]) - scheduler_job = Job() - self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull) + # The DAG parser finds asset5, but it's not activated yet. 
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1, asset2, asset3, asset4]
+        assert orphaned == [asset5]

-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()

-        # and find the orphans
-        non_orphaned_assets = [
-            asset.uri
-            for asset in session.query(AssetModel.uri)
-            .filter(AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        ]
-        assert non_orphaned_assets == ["ds1", "ds3"]
-        orphaned_assets = session.scalars(
-            select(AssetModel.uri).where(~AssetModel.active.has()).order_by(AssetModel.uri)
-        ).all()
-        assert orphaned_assets == ["ds2", "ds4"]
+        # Now we get the updated result.
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1, asset3, asset5]
+        assert orphaned == [asset2, asset4]

+    @pytest.mark.want_activate_assets(False)
     def test_asset_orphaning_ignore_orphaned_assets(self, dag_maker, session):
+        self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)
+
         asset1 = Asset(uri="ds1")

         with dag_maker(dag_id="assets-1", schedule=[asset1], session=session):
             BashOperator(task_id="task", bash_command="echo 1")

-        non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count()
-        assert non_orphaned_asset_count == 1
-        orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count()
-        assert orphaned_asset_count == 0
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+
+        self.job_runner._update_asset_orphanage(session=session)
+        session.flush()
+
+        orphaned, active = self._find_assets_activation(session)
+        assert active == [asset1]
+        assert orphaned == []

         # now remove asset1 reference
         with dag_maker(dag_id="assets-1", schedule=None, session=session):
             BashOperator(task_id="task", bash_command="echo 1")

-        scheduler_job = Job()
-        self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
-
-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()

-        orphaned_assets_before_rerun = (
-            session.query(AssetModel.updated_at, AssetModel.uri)
-            .filter(~AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        )
-        assert [asset.uri for asset in orphaned_assets_before_rerun] == ["ds1"]
-        updated_at_timestamps = [asset.updated_at for asset in orphaned_assets_before_rerun]
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+        updated_at_timestamps = [asset.updated_at for asset in orphaned]

         # when rerunning we should ignore the already orphaned assets and thus the updated_at timestamp
         # should remain the same
-        self.job_runner._orphan_unreferenced_assets(session=session)
+        self.job_runner._update_asset_orphanage(session=session)
         session.flush()

-        orphaned_assets_after_rerun = (
-            session.query(AssetModel.updated_at, AssetModel.uri)
-            .filter(~AssetModel.active.has())
-            .order_by(AssetModel.uri)
-        )
-        assert [asset.uri for asset in orphaned_assets_after_rerun] == ["ds1"]
-        assert updated_at_timestamps == [asset.updated_at for asset in orphaned_assets_after_rerun]
+        orphaned, active = self._find_assets_activation(session)
+        assert active == []
+        assert orphaned == [asset1]
+        assert [asset.updated_at for asset in orphaned] == updated_at_timestamps

     def test_misconfigured_dags_doesnt_crash_scheduler(self, session, dag_maker, caplog):
         """Test that if dagrun creation throws an exception, the scheduler doesn't crash"""
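
Not part of this patch, but a test along these lines could exercise the new conflict path; it assumes the fixtures and imports used in this file, plus `DagWarning`/`DagWarningType` from `airflow.models.dagwarning` and the `Asset(name=..., uri=...)` keyword API:

```python
# Hypothetical addition (illustrative only): two assets sharing a URI but not a name should
# leave the second one un-activated and produce an ASSET_CONFLICT warning for the DAG.
@pytest.mark.want_activate_assets(False)
def test_asset_conflict_emits_dag_warning(self, dag_maker, session):
    self.job_runner = SchedulerJobRunner(job=Job(), subdir=os.devnull)

    asset_a = Asset(name="a", uri="test://asset1")
    asset_b = Asset(name="b", uri="test://asset1")  # same URI, different name

    with dag_maker(dag_id="assets-conflict", schedule=[asset_a, asset_b], session=session):
        BashOperator(task_id="task", bash_command="echo 1")

    self.job_runner._update_asset_orphanage(session=session)
    session.flush()

    warning = session.scalars(select(DagWarning).where(DagWarning.dag_id == "assets-conflict")).one()
    assert warning.warning_type == DagWarningType.ASSET_CONFLICT
```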
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 86e499edb3bd7..1d7a69ba84376 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -47,6 +47,7 @@
     UnknownExecutorException,
 )
 from airflow.models.asset import (
+    AssetActive,
     AssetAliasModel,
     AssetDagRunQueue,
     AssetEvent,
@@ -1070,54 +1071,47 @@ def test_bulk_write_to_db_assets(self):
             .all()
         ) == {(task_id, dag_id1, asset2_orm.id)}

-    def test_bulk_write_to_db_unorphan_assets(self):
+    @staticmethod
+    def _find_assets_activation(session) -> tuple[list[AssetModel], list[AssetModel]]:
+        assets = session.execute(
+            select(AssetModel, AssetActive)
+            .outerjoin(
+                AssetActive,
+                (AssetModel.name == AssetActive.name) & (AssetModel.uri == AssetActive.uri),
+            )
+            .order_by(AssetModel.uri)
+        ).all()
+        return [a for a, v in assets if not v], [a for a, v in assets if v]
+
+    def test_bulk_write_to_db_does_not_activate(self, dag_maker, session):
         """
-        Assets can lose their last reference and be orphaned, but then if a reference to them reappears, we
-        need to un-orphan those assets
+        Assets are not activated on write; they are activated later by the scheduler.
         """
-        with create_session() as session:
-            # Create four assets - two that have references and two that are unreferenced and marked as
-            # orphans
-            asset1 = Asset(uri="ds1")
-            asset2 = Asset(uri="ds2")
-            session.add(AssetModel(uri=asset2.uri))
-            asset3 = Asset(uri="ds3")
-            asset4 = Asset(uri="ds4")
-            session.add(AssetModel(uri=asset4.uri))
-            session.flush()
-
-            dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1])
-            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3])
-
-            DAG.bulk_write_to_db([dag1], session=session)
-
-            # Double check
-            non_orphaned_assets = [
-                asset.uri
-                for asset in session.query(AssetModel.uri)
-                .filter(AssetModel.active.has())
-                .order_by(AssetModel.uri)
-            ]
-            assert non_orphaned_assets == ["ds1", "ds3"]
-            orphaned_assets = [
-                asset.uri
-                for asset in session.query(AssetModel.uri)
-                .filter(~AssetModel.active.has())
-                .order_by(AssetModel.uri)
-            ]
-            assert orphaned_assets == ["ds2", "ds4"]
+        # Create four assets - two that are referenced by the first DAG version below, and two that
+        # stay unreferenced until the DAG is updated later in the test.
+        asset1 = Asset(uri="ds1")
+        asset2 = Asset(uri="ds2")
+        asset3 = Asset(uri="ds3")
+        asset4 = Asset(uri="ds4")

-            # Now add references to the two unreferenced assets
-            dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2])
-            BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4])
+        dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1])
+        BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3])
+        DAG.bulk_write_to_db([dag1], session=session)

-            DAG.bulk_write_to_db([dag1], session=session)
+        assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [asset1, asset3]
+        assert session.scalars(select(AssetActive)).all() == []

-            # and count the orphans and non-orphans
-            non_orphaned_asset_count = session.query(AssetModel).filter(AssetModel.active.has()).count()
-            assert non_orphaned_asset_count == 4
-            orphaned_asset_count = session.query(AssetModel).filter(~AssetModel.active.has()).count()
-            assert orphaned_asset_count == 0
+        dag1 = DAG(dag_id="assets-1", start_date=DEFAULT_DATE, schedule=[asset1, asset2])
+        BashOperator(dag=dag1, task_id="task", bash_command="echo 1", outlets=[asset3, asset4])
+        DAG.bulk_write_to_db([dag1], session=session)
+
+        assert session.scalars(select(AssetModel).order_by(AssetModel.uri)).all() == [
+            asset1,
+            asset2,
+            asset3,
+            asset4,
+        ]
+        assert session.scalars(select(AssetActive)).all() == []

     def test_bulk_write_to_db_asset_aliases(self):
         """
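
One detail the asserts above lean on: `Asset(uri="ds1")` defaults its name to the URI, and `AssetModel` compares equal to a plain `Asset` with the same name/URI pair (its `(name, uri)` hash is visible in the asset.py hunk). A rough sketch of that contract, using hypothetical stand-in classes rather than the real models:

```python
from dataclasses import dataclass


@dataclass
class PublicAsset:  # stand-in for airflow.assets.Asset
    uri: str
    name: str = ""

    def __post_init__(self) -> None:
        self.name = self.name or self.uri  # name defaults to the URI when not given


class OrmAsset:  # stand-in for AssetModel; the real equality lives in airflow/models/asset.py
    def __init__(self, name: str, uri: str) -> None:
        self.name, self.uri = name, uri

    def __eq__(self, other: object) -> bool:
        return self.name == getattr(other, "name", None) and self.uri == getattr(other, "uri", None)

    def __hash__(self) -> int:
        return hash((self.name, self.uri))


assert OrmAsset("ds1", "ds1") == PublicAsset(uri="ds1")  # why orm_rows == [asset1, ...] can hold
```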
diff --git a/tests_common/pytest_plugin.py b/tests_common/pytest_plugin.py
index 0c7ed57fea668..18c3779d7fa31 100644
--- a/tests_common/pytest_plugin.py
+++ b/tests_common/pytest_plugin.py
@@ -404,6 +404,7 @@ def pytest_configure(config: pytest.Config) -> None:
     config.addinivalue_line(
         "markers", "need_serialized_dag: mark tests that require dags in serialized form to be present"
     )
+    config.addinivalue_line("markers", "want_activate_assets: mark tests that require assets to be activated")
     config.addinivalue_line(
         "markers",
         "db_test: mark tests that require database to be present",
@@ -759,12 +760,14 @@ def dag_maker(request):
     # and "baked" in to various constants

     want_serialized = False
+    want_activate_assets = True  # Only has effect if want_serialized=True on Airflow 3.

     # Allow changing default serialized behaviour with `@pytest.mark.need_serialized_dag` or
     # `@pytest.mark.need_serialized_dag(False)`
-    serialized_marker = request.node.get_closest_marker("need_serialized_dag")
-    if serialized_marker:
+    if serialized_marker := request.node.get_closest_marker("need_serialized_dag"):
         (want_serialized,) = serialized_marker.args or (True,)
+    if serialized_marker := request.node.get_closest_marker("want_activate_assets"):
+        (want_activate_assets,) = serialized_marker.args or (True,)

     from airflow.utils.log.logging_mixin import LoggingMixin

@@ -802,10 +805,26 @@ def _bag_dag_compat(self, dag):
                 return self.dagbag.bag_dag(dag, root_dag=dag)
             return self.dagbag.bag_dag(dag)

+        def _activate_assets(self):
+            from sqlalchemy import select
+
+            from airflow.jobs.scheduler_job_runner import SchedulerJobRunner
+            from airflow.models.asset import AssetModel, DagScheduleAssetReference, TaskOutletAssetReference
+
+            assets = self.session.scalars(
+                select(AssetModel).where(
+                    AssetModel.consuming_dags.any(DagScheduleAssetReference.dag_id == self.dag.dag_id)
+                    | AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id == self.dag.dag_id)
+                )
+            ).all()
+            SchedulerJobRunner._activate_referenced_assets(assets, session=self.session)
+
         def __exit__(self, type, value, traceback):
             from airflow.models import DagModel
             from airflow.models.serialized_dag import SerializedDagModel

+            from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
             dag = self.dag
             dag.__exit__(type, value, traceback)
             if type is not None:
@@ -822,6 +841,8 @@ def __exit__(self, type, value, traceback):
                 self.session.merge(self.serialized_model)
                 serialized_dag = self._serialized_dag()
                 self._bag_dag_compat(serialized_dag)
+                if AIRFLOW_V_3_0_PLUS and self.want_activate_assets:
+                    self._activate_assets()
                 self.session.flush()
             else:
                 self._bag_dag_compat(self.dag)
@@ -887,6 +908,7 @@ def __call__(
         dag_id="test_dag",
         schedule=timedelta(days=1),
         serialized=want_serialized,
+        activate_assets=want_activate_assets,
         fileloc=None,
         processor_subdir=None,
         session=None,
@@ -919,6 +941,7 @@ def __call__(
         self.dag = DAG(dag_id, schedule=schedule, **self.kwargs)
         self.dag.fileloc = fileloc or request.module.__file__
         self.want_serialized = serialized
+        self.want_activate_assets = activate_assets
         self.processor_subdir = processor_subdir
         return self
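
For completeness, this is roughly how the new marker is meant to be used from a test, assuming the `dag_maker` and `session` fixtures provided by this plugin (test and asset names are illustrative):

```python
import pytest

from airflow.assets import Asset


# By default dag_maker now activates the assets a DAG references (on Airflow 3, when the DAG is
# serialized), mirroring what SchedulerJobRunner._activate_referenced_assets does in production.
def test_runs_with_active_assets(dag_maker, session):
    with dag_maker(dag_id="example", schedule=[Asset(uri="test://asset1")], session=session):
        ...


# Opting out leaves asset_active empty, so a test can drive activation itself, as the scheduler
# and DAG-model tests above do.
@pytest.mark.want_activate_assets(False)
def test_drives_activation_itself(dag_maker, session):
    ...
```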