Skip to content

Commit

Permalink
openlineage: add unit test for listener hooks on dag run state change…
Browse files Browse the repository at this point in the history
…s. (apache#42554)

openlineage: cover task instance failure in unit tests.

Signed-off-by: Jakub Dardzinski <[email protected]>
  • Loading branch information
JDarDagran authored Oct 1, 2024
1 parent 8c9d251 commit e46365d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 1 deletion.
12 changes: 11 additions & 1 deletion tests/dags/test_openlineage_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@


class OpenLineageExecutionOperator(BaseOperator):
def __init__(self, *, stall_amount=0, **kwargs) -> None:
def __init__(self, *, stall_amount=0, fail=False, **kwargs) -> None:
    """Initialise the test operator.

    :param stall_amount: stored on the instance as ``stall_amount``.
    :param fail: stored on the instance as ``fail``.
    :param kwargs: forwarded unchanged to the parent initialiser.
    """
    super().__init__(**kwargs)
    self.fail = fail
    self.stall_amount = stall_amount

def execute(self, context):
    """Log the stall amount, sleep one second, then fail if requested.

    ``context`` is accepted but not used.

    :raises Exception: with the message ``"Failed"`` when ``self.fail`` is truthy.
    """
    self.log.error("STALL AMOUNT %s", self.stall_amount)
    time.sleep(1)
    if not self.fail:
        return
    raise Exception("Failed")

def get_openlineage_facets_on_start(self):
    """Return lineage with one input dataset: namespace ``test``, name ``on-start``."""
    start_input = Dataset(namespace="test", name="on-start")
    return OperatorLineage(inputs=[start_input])
Expand All @@ -43,6 +46,11 @@ def get_openlineage_facets_on_complete(self, task_instance):
time.sleep(self.stall_amount)
return OperatorLineage(inputs=[Dataset(namespace="test", name="on-complete")])

def get_openlineage_facets_on_failure(self, task_instance):
    """Log and sleep for ``stall_amount``, then return the failure lineage.

    ``task_instance`` is accepted but not used. The returned lineage has one
    input dataset: namespace ``test``, name ``on-failure``.
    """
    stall = self.stall_amount
    self.log.error("STALL AMOUNT %s", stall)
    time.sleep(stall)
    failure_input = Dataset(namespace="test", name="on-failure")
    return OperatorLineage(inputs=[failure_input])


with DAG(
dag_id="test_openlineage_execution",
Expand All @@ -57,3 +65,5 @@ def get_openlineage_facets_on_complete(self, task_instance):
mid_stall = OpenLineageExecutionOperator(task_id="execute_mid_stall", stall_amount=15)

long_stall = OpenLineageExecutionOperator(task_id="execute_long_stall", stall_amount=30)

fail = OpenLineageExecutionOperator(task_id="execute_fail", fail=True)
11 changes: 11 additions & 0 deletions tests/providers/openlineage/plugins/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ def test_not_stalled_task_emits_proper_lineage(self):
assert has_value_in_events(events, ["inputs", "name"], "on-start")
assert has_value_in_events(events, ["inputs", "name"], "on-complete")

@pytest.mark.db_test
@conf_vars({("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}'})
def test_not_stalled_failing_task_emits_proper_lineage(self):
    """Run the ``execute_fail`` task and check both lineage inputs were recorded.

    The events read back from ``tmp_dir`` must contain an input named
    ``on-start`` and one named ``on-failure``.
    """
    self.setup_job("execute_fail", "test_failure")

    emitted = get_sorted_events(tmp_dir)
    for expected_name in ("on-start", "on-failure"):
        assert has_value_in_events(emitted, ["inputs", "name"], expected_name)

@conf_vars(
{
("openlineage", "transport"): f'{{"type": "file", "log_file_path": "{listener_path}"}}',
Expand Down
43 changes: 43 additions & 0 deletions tests/providers/openlineage/plugins/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,49 @@ def test_listener_on_dag_run_state_changes_configure_process_pool_size(mock_exec
mock_executor.return_value.submit.assert_called_once()


class MockExecutor:
    """Synchronous stand-in for an executor: runs submitted callables inline.

    The outcome of the most recent ``submit`` call is exposed through plain
    attributes so a test can inspect it:

    * ``submitted`` -- ``True`` once ``submit`` has been called.
    * ``succeeded`` -- ``True`` only if the callable returned without raising.
    * ``result`` -- the callable's return value (``None`` until it succeeds).
    * ``exception`` -- the exception the callable raised (``None`` otherwise).
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are accepted and ignored.
        self.submitted = False
        self.succeeded = False
        self.result = None
        self.exception = None

    def submit(self, fn, /, *args, **kwargs):
        """Call ``fn(*args, **kwargs)`` immediately and record the outcome.

        Exceptions raised by ``fn`` are never propagated to the caller; they
        are stored on ``self.exception`` so a failing test can show the cause.

        :param fn: callable to run inline.
        :return: a fresh ``MagicMock`` standing in for the returned future.
        """
        self.submitted = True
        try:
            self.result = fn(*args, **kwargs)
        except Exception as exc:
            # Deliberately not re-raised: submit() must not fail in the caller.
            # Keep the exception instead of silently discarding it.
            self.exception = exc
        else:
            self.succeeded = True
        return MagicMock()

    def shutdown(self, *args, **kwargs):
        """Accept any arguments and only print a notice; nothing is released."""
        print("Shutting down")


@pytest.mark.parametrize(
    ("method", "dag_run_state"),
    [
        ("on_dag_run_running", DagRunState.RUNNING),
        ("on_dag_run_success", DagRunState.SUCCESS),
        ("on_dag_run_failed", DagRunState.FAILED),
    ],
)
@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.emit")
def test_listener_on_dag_run_state_changes(mock_emit, method, dag_run_state, create_task_instance):
    """Check each dag-run listener hook submits work that emits exactly once.

    ``ProcessPoolExecutor`` in the listener module is patched to return a
    ``MockExecutor``. The test then asserts that the hook submitted a
    callable, that the callable succeeded, and that the patched
    ``OpenLineageAdapter.emit`` was called once.
    """
    task_instance = create_task_instance(dag_id="dag", task_id="op")
    dag_run = task_instance.dag_run
    # Change the state explicitly to set end_date following the logic in the method
    dag_run.set_state(dag_run_state)

    inline_executor = MockExecutor()
    executor_patch = mock.patch(
        "airflow.providers.openlineage.plugins.listener.ProcessPoolExecutor", return_value=inline_executor
    )
    with executor_patch:
        hook = getattr(OpenLineageListener(), method)
        hook(dag_run, None)
        assert inline_executor.submitted is True
        assert inline_executor.succeeded is True
        mock_emit.assert_called_once()


def test_listener_logs_failed_serialization():
listener = OpenLineageListener()
callback_future = Future()
Expand Down

0 comments on commit e46365d

Please sign in to comment.