diff --git a/tests/datasets/test_manager.py b/tests/datasets/test_manager.py index 25f0d9e52bcfa..3d3f4dca9260e 100644 --- a/tests/datasets/test_manager.py +++ b/tests/datasets/test_manager.py @@ -18,15 +18,18 @@ from __future__ import annotations import itertools +from datetime import datetime from unittest import mock import pytest +from sqlalchemy import delete from airflow.datasets import Dataset from airflow.datasets.manager import DatasetManager from airflow.listeners.listener import get_listener_manager from airflow.models.dag import DagModel from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel +from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic from tests.listeners import dataset_listener pytestmark = pytest.mark.db_test @@ -34,12 +37,47 @@ @pytest.fixture def mock_task_instance(): - mock_ti = mock.Mock() - mock_ti.task_id = "5" - mock_ti.dag_id = "7" - mock_ti.run_id = "11" - mock_ti.map_index = "13" - return mock_ti + return TaskInstancePydantic( + task_id="5", + dag_id="7", + run_id="11", + map_index="13", + start_date=datetime.now(), + end_date=datetime.now(), + execution_date=datetime.now(), + duration=0.1, + state="success", + try_number=1, + max_tries=4, + hostname="host", + unixname="unix", + job_id=13, + pool="default", + pool_slots=1, + queue="default", + priority_weight=77, + operator="DummyOperator", + custom_operator_name="DummyOperator", + queued_dttm=datetime.now(), + queued_by_job_id=3, + pid=12345, + executor="default", + executor_config=None, + updated_at=datetime.now(), + rendered_map_index="1", + external_executor_id="x", + trigger_id=1, + trigger_timeout=datetime.now(), + next_method="bla", + next_kwargs=None, + run_as_user=None, + task=None, + test_mode=False, + dag_run=None, + dag_model=None, + raw=False, + is_trigger_log_context=False, + ) def create_mock_dag(): @@ -77,7 +115,8 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance): dsm = DatasetModel(uri="test_dataset_uri") session.add(dsm) dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)] - session.flush() + session.execute(delete(DatasetDagRunQueue)) + session.commit() dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) @@ -91,7 +130,8 @@ def test_register_dataset_change_no_downstreams(self, session, mock_task_instanc ds = Dataset(uri="never_consumed") dsm = DatasetModel(uri="never_consumed") session.add(dsm) - session.flush() + session.execute(delete(DatasetDagRunQueue)) + session.commit() dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) @@ -99,19 +139,20 @@ def test_register_dataset_change_no_downstreams(self, session, mock_task_instanc assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1 assert session.query(DatasetDagRunQueue).count() == 0 + @pytest.mark.skip_if_database_isolation_mode def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance): dsem = DatasetManager() dataset_listener.clear() get_listener_manager().add_listener(dataset_listener) - ds = Dataset(uri="test_dataset_uri") - dag1 = DagModel(dag_id="dag1") + ds = Dataset(uri="test_dataset_uri_2") + dag1 = DagModel(dag_id="dag3") session.add_all([dag1]) - dsm = DatasetModel(uri="test_dataset_uri") + dsm = DatasetModel(uri="test_dataset_uri_2") session.add(dsm) dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)] - session.flush() + session.commit() dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session) @@ -119,12 +160,13 @@ def test_register_dataset_change_notifies_dataset_listener(self, session, mock_t assert len(dataset_listener.changed) == 1 assert dataset_listener.changed[0].uri == ds.uri + @pytest.mark.skip_if_database_isolation_mode def test_create_datasets_notifies_dataset_listener(self, session): dsem = DatasetManager() dataset_listener.clear() get_listener_manager().add_listener(dataset_listener) - dsm = DatasetModel(uri="test_dataset_uri") + dsm = DatasetModel(uri="test_dataset_uri_3") dsem.create_datasets([dsm], session)