diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py index 427bf8a09615c..939e5bbcac716 100644 --- a/airflow/providers/google/cloud/triggers/dataproc.py +++ b/airflow/providers/google/cloud/triggers/dataproc.py @@ -22,16 +22,22 @@ import asyncio import re import time -from typing import Any, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence from google.api_core.exceptions import NotFound from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus from airflow.exceptions import AirflowException +from airflow.models.taskinstance import TaskInstance from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.utils.session import provide_session +from airflow.utils.state import TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session class DataprocBaseTrigger(BaseTrigger): @@ -178,6 +184,36 @@ def serialize(self) -> tuple[str, dict[str, Any]]: }, ) + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found.", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + self.task_instance.map_index, + ) + return task_instance + + def safe_to_cancel(self) -> bool: + """ + Whether it is safe to cancel the external job which is being executed by this trigger. + + This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. + Because in those cases, we should NOT cancel the external job. + """ + # Database query is needed to get the latest state of the task instance. + task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state != TaskInstanceState.DEFERRED + async def run(self) -> AsyncIterator[TriggerEvent]: try: while True: @@ -207,7 +243,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: await asyncio.sleep(self.polling_interval_seconds) except asyncio.CancelledError: try: - if self.delete_on_error: + if self.delete_on_error and self.safe_to_cancel(): + self.log.info( + "Deleting the cluster as it is safe to delete as the airflow TaskInstance is not in " + "deferred state." + ) self.log.info("Deleting cluster %s.", self.cluster_name) # The synchronous hook is utilized to delete the cluster when a task is cancelled. # This is because the asynchronous hook deletion is not awaited when the trigger task diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py index f41fc3a280283..08294a5ac59d2 100644 --- a/tests/providers/google/cloud/triggers/test_dataproc.py +++ b/tests/providers/google/cloud/triggers/test_dataproc.py @@ -18,7 +18,7 @@ import asyncio import logging -from asyncio import Future +from asyncio import CancelledError, Future, sleep from unittest import mock import pytest @@ -60,6 +60,7 @@ def cluster_trigger(): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=None, polling_interval_seconds=TEST_POLL_INTERVAL, + delete_on_error=True, ) @@ -328,6 +329,38 @@ async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_tri mock_delete_cluster.assert_not_called() + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel") + async def test_cluster_trigger_run_cancelled_not_safe_to_cancel( + self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, cluster_trigger + ): + """Test the trigger's cancellation behavior when it is not safe to cancel.""" + mock_safe_to_cancel.return_value = False + cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING)) + future_cluster = asyncio.Future() + future_cluster.set_result(cluster) + mock_get_async_hook.return_value.get_cluster.return_value = future_cluster + + mock_delete_cluster = mock.MagicMock() + mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster + + cluster_trigger.delete_on_error = True + + async_gen = cluster_trigger.run() + task = asyncio.create_task(async_gen.__anext__()) + await sleep(0) + task.cancel() + + try: + await task + except CancelledError: + pass + + assert mock_delete_cluster.call_count == 0 + mock_delete_cluster.assert_not_called() + @pytest.mark.db_test class TestDataprocBatchTrigger: