From e07a42e69d1ab472c4da991fca5782990607ebe0 Mon Sep 17 00:00:00 2001
From: Wei Lee
Date: Mon, 22 Jan 2024 14:32:00 +0800
Subject: [PATCH] Check cluster state before deferring Dataproc operators to
 the trigger (#36892)

While operating a Dataproc cluster in deferrable mode, the condition might
already be met (cluster created, deleted, or updated) before we defer the
task to the trigger. This PR checks the cluster state before deferring the
task to the trigger, so the operator can finish immediately when possible.

---------

Co-authored-by: Pankaj Koti
---
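Notes for reviewers (commentary below the cut line, not applied by git am):
each operator now makes one extra synchronous get_cluster() call and defers
only when the desired state has not been reached yet. Below is a minimal,
self-contained sketch of that pattern; get_state and defer are hypothetical
stand-ins for hook.get_cluster(...).status.state and BaseOperator.defer(),
not the actual Airflow API.

    from enum import Enum

    class State(Enum):
        CREATING = "CREATING"
        RUNNING = "RUNNING"

    def execute(get_state, defer):
        # One extra synchronous status check before handing off to the trigger.
        if get_state() is State.RUNNING:
            return "done"  # condition already met: finish without deferring
        return defer()     # otherwise defer to the trigger as before

    # Condition met before deferral: no triggerer round trip is needed.
    assert execute(lambda: State.RUNNING, lambda: "deferred") == "done"
    # Condition still pending: the task defers exactly as it used to.
    assert execute(lambda: State.CREATING, lambda: "deferred") == "deferred"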
 .../google/cloud/operators/dataproc.py        |  63 +++++---
 .../google/cloud/operators/test_dataproc.py   | 146 +++++++++++++++++-
 2 files changed, 185 insertions(+), 24 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 306e0dc03d6e3..b14121139d0ad 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -721,6 +721,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
     def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+
         # Save data required to display extra link no matter what the cluster status will be
         project_id = self.project_id or hook.project_id
         if project_id:
@@ -731,6 +732,7 @@ def execute(self, context: Context) -> dict:
                 project_id=project_id,
                 region=self.region,
             )
+
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
@@ -741,17 +743,24 @@ def execute(self, context: Context) -> dict:
                 self.log.info("Cluster created.")
                 return Cluster.to_dict(cluster)
             else:
-                self.defer(
-                    trigger=DataprocClusterTrigger(
-                        cluster_name=self.cluster_name,
-                        project_id=self.project_id,
-                        region=self.region,
-                        gcp_conn_id=self.gcp_conn_id,
-                        impersonation_chain=self.impersonation_chain,
-                        polling_interval_seconds=self.polling_interval_seconds,
-                    ),
-                    method_name="execute_complete",
+                cluster = hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
                 )
+                if cluster.status.state == cluster.status.State.RUNNING:
+                    self.log.info("Cluster created.")
+                    return Cluster.to_dict(cluster)
+                else:
+                    self.defer(
+                        trigger=DataprocClusterTrigger(
+                            cluster_name=self.cluster_name,
+                            project_id=self.project_id,
+                            region=self.region,
+                            gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain,
+                            polling_interval_seconds=self.polling_interval_seconds,
+                        ),
+                        method_name="execute_complete",
+                    )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -1016,6 +1025,16 @@ def execute(self, context: Context) -> None:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
             self.log.info("Cluster deleted.")
         else:
+            try:
+                hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+                )
+            except NotFound:
+                self.log.info("Cluster deleted.")
+                return
+            except Exception as e:
+                raise AirflowException(str(e))
+
             end_time: float = time.time() + self.timeout
             self.defer(
                 trigger=DataprocDeleteClusterTrigger(
@@ -2480,17 +2499,21 @@ def execute(self, context: Context):
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
         else:
-            self.defer(
-                trigger=DataprocClusterTrigger(
-                    cluster_name=self.cluster_name,
-                    project_id=self.project_id,
-                    region=self.region,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval_seconds=self.polling_interval_seconds,
-                ),
-                method_name="execute_complete",
+            cluster = hook.get_cluster(
+                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
             )
+            if cluster.status.state != cluster.status.State.RUNNING:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         self.log.info("Updated %s cluster.", self.cluster_name)

     def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 59a9c1008c4a3..00f45ca8b3db0 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -23,7 +23,8 @@
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
-from google.cloud.dataproc_v1 import Batch, JobStatus
+from google.cloud import dataproc
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus

 from airflow.exceptions import (
     AirflowException,
@@ -579,7 +580,7 @@ def test_build_with_flex_migs(self):
         assert CONFIG_WITH_FLEX_MIG == cluster


-class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
+class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
             op = DataprocCreateClusterOperator(
@@ -883,6 +884,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.create_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            request_id=None,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            virtual_cluster_config=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+

 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
@@ -1100,6 +1149,47 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

+    @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        mock_hook.return_value.create_cluster.return_value = None
+        mock_hook.return_value.get_cluster.side_effect = NotFound("test")
+        operator = DataprocDeleteClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            request_id=REQUEST_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster_uuid=None,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+

 class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1240,8 +1330,8 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
+    @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
     def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook):
         mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
         job_status = mock_hook.return_value.get_job.return_value.status
@@ -1498,6 +1588,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.update_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.update_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+

 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
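Reviewer note on the delete path (appended commentary, not part of the diff):
deletion inverts the check — NotFound is the success signal, so the operator
returns early instead of deferring. A rough, self-contained sketch of that
logic follows; the local NotFound class and fetch_cluster callable are
stand-ins for google.api_core.exceptions.NotFound and hook.get_cluster().

    class NotFound(Exception):
        """Stand-in for google.api_core.exceptions.NotFound."""

    def already_deleted(fetch_cluster) -> bool:
        # For delete, the cluster being gone *is* the terminal condition.
        try:
            fetch_cluster()
        except NotFound:
            return True   # already deleted: finish before deferring
        return False      # still exists: defer to the trigger as before

    def fetch_missing():
        raise NotFound("cluster not found")

    assert already_deleted(fetch_missing) is True
    assert already_deleted(lambda: {"status": "DELETING"}) is False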