Fix Manager Tests for Dataset Isolation Mode (apache#41143)
jscheffl authored Jul 31, 2024
1 parent 206ce3e commit 3a915ae
Showing 1 changed file with 55 additions and 13 deletions.
68 changes: 55 additions & 13 deletions tests/datasets/test_manager.py
@@ -18,28 +18,66 @@
 from __future__ import annotations
 
 import itertools
+from datetime import datetime
 from unittest import mock
 
 import pytest
+from sqlalchemy import delete
 
 from airflow.datasets import Dataset
 from airflow.datasets.manager import DatasetManager
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.dag import DagModel
 from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
+from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
 from tests.listeners import dataset_listener
 
 pytestmark = pytest.mark.db_test
 
 
 @pytest.fixture
 def mock_task_instance():
-    mock_ti = mock.Mock()
-    mock_ti.task_id = "5"
-    mock_ti.dag_id = "7"
-    mock_ti.run_id = "11"
-    mock_ti.map_index = "13"
-    return mock_ti
+    return TaskInstancePydantic(
+        task_id="5",
+        dag_id="7",
+        run_id="11",
+        map_index="13",
+        start_date=datetime.now(),
+        end_date=datetime.now(),
+        execution_date=datetime.now(),
+        duration=0.1,
+        state="success",
+        try_number=1,
+        max_tries=4,
+        hostname="host",
+        unixname="unix",
+        job_id=13,
+        pool="default",
+        pool_slots=1,
+        queue="default",
+        priority_weight=77,
+        operator="DummyOperator",
+        custom_operator_name="DummyOperator",
+        queued_dttm=datetime.now(),
+        queued_by_job_id=3,
+        pid=12345,
+        executor="default",
+        executor_config=None,
+        updated_at=datetime.now(),
+        rendered_map_index="1",
+        external_executor_id="x",
+        trigger_id=1,
+        trigger_timeout=datetime.now(),
+        next_method="bla",
+        next_kwargs=None,
+        run_as_user=None,
+        task=None,
+        test_mode=False,
+        dag_run=None,
+        dag_model=None,
+        raw=False,
+        is_trigger_log_context=False,
+    )
 
 
 def create_mock_dag():
@@ -77,7 +115,8 @@ def test_register_dataset_change(self, session, dag_maker, mock_task_instance):
         dsm = DatasetModel(uri="test_dataset_uri")
         session.add(dsm)
         dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)]
-        session.flush()
+        session.execute(delete(DatasetDagRunQueue))
+        session.commit()
 
         dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)
 
@@ -91,40 +130,43 @@ def test_register_dataset_change_no_downstreams(self, session, mock_task_instance):
         ds = Dataset(uri="never_consumed")
         dsm = DatasetModel(uri="never_consumed")
         session.add(dsm)
-        session.flush()
+        session.execute(delete(DatasetDagRunQueue))
+        session.commit()
 
         dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)
 
         # Ensure we've created a dataset
         assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
         assert session.query(DatasetDagRunQueue).count() == 0
 
+    @pytest.mark.skip_if_database_isolation_mode
     def test_register_dataset_change_notifies_dataset_listener(self, session, mock_task_instance):
         dsem = DatasetManager()
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)
 
-        ds = Dataset(uri="test_dataset_uri")
-        dag1 = DagModel(dag_id="dag1")
+        ds = Dataset(uri="test_dataset_uri_2")
+        dag1 = DagModel(dag_id="dag3")
         session.add_all([dag1])
 
-        dsm = DatasetModel(uri="test_dataset_uri")
+        dsm = DatasetModel(uri="test_dataset_uri_2")
         session.add(dsm)
         dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag1.dag_id)]
-        session.flush()
+        session.commit()
 
         dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)
 
         # Ensure the listener was notified
         assert len(dataset_listener.changed) == 1
         assert dataset_listener.changed[0].uri == ds.uri
 
+    @pytest.mark.skip_if_database_isolation_mode
     def test_create_datasets_notifies_dataset_listener(self, session):
         dsem = DatasetManager()
         dataset_listener.clear()
         get_listener_manager().add_listener(dataset_listener)
 
-        dsm = DatasetModel(uri="test_dataset_uri")
+        dsm = DatasetModel(uri="test_dataset_uri_3")
 
         dsem.create_datasets([dsm], session)
 
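Note: the two listener tests are now guarded by @pytest.mark.skip_if_database_isolation_mode, a custom marker used by the Airflow test suite. Below is a minimal sketch of how such a marker can be wired up in a pytest conftest.py; the environment variable name (RUN_TESTS_WITH_DATABASE_ISOLATION) and the hook wiring are illustrative assumptions, not taken from this commit or from Airflow's actual conftest.

# conftest.py -- illustrative sketch only; the real Airflow conftest may differ.
# Assumption: database isolation mode is signalled via an environment variable.
import os

import pytest

DATABASE_ISOLATION_MODE = os.environ.get("RUN_TESTS_WITH_DATABASE_ISOLATION", "false").lower() == "true"


def pytest_configure(config):
    # Register the custom marker so pytest does not warn about unknown markers.
    config.addinivalue_line(
        "markers",
        "skip_if_database_isolation_mode: skip test when the suite runs in database isolation mode",
    )


def pytest_collection_modifyitems(config, items):
    # When the isolation flag is set, attach a skip marker to every test
    # that declared @pytest.mark.skip_if_database_isolation_mode.
    if not DATABASE_ISOLATION_MODE:
        return
    skip = pytest.mark.skip(reason="Test is not compatible with database isolation mode")
    for item in items:
        if item.get_closest_marker("skip_if_database_isolation_mode"):
            item.add_marker(skip)

With wiring like this, the listener tests above are collected but skipped whenever the isolation flag is set, while the remaining tests keep running against the TaskInstancePydantic-based fixture.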