Skip to content

Commit

Permalink
Check cluster state before defer Dataproc operators to trigger (apach…
Browse files Browse the repository at this point in the history
…e#36892)

While operating a data proc cluster in deferrable mode, the condition might already be met (created, deleted, updated) before we defer the task into the trigger. This PR intends to check thecluster status before deferring the task to trigger.
---------

Co-authored-by: Pankaj Koti <[email protected]>
  • Loading branch information
Lee-W and pankajkoti authored Jan 22, 2024
1 parent d48985c commit e07a42e
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 24 deletions.
63 changes: 43 additions & 20 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
def execute(self, context: Context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

# Save data required to display extra link no matter what the cluster status will be
project_id = self.project_id or hook.project_id
if project_id:
Expand All @@ -731,6 +732,7 @@ def execute(self, context: Context) -> dict:
project_id=project_id,
region=self.region,
)

try:
# First try to create a new cluster
operation = self._create_cluster(hook)
Expand All @@ -741,17 +743,24 @@ def execute(self, context: Context) -> dict:
self.log.info("Cluster created.")
return Cluster.to_dict(cluster)
else:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
cluster = hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
if cluster.status.state == cluster.status.State.RUNNING:
self.log.info("Cluster created.")
return Cluster.to_dict(cluster)
else:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
except AlreadyExists:
if not self.use_if_exists:
raise
Expand Down Expand Up @@ -1016,6 +1025,16 @@ def execute(self, context: Context) -> None:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
self.log.info("Cluster deleted.")
else:
try:
hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
except NotFound:
self.log.info("Cluster deleted.")
return
except Exception as e:
raise AirflowException(str(e))

end_time: float = time.time() + self.timeout
self.defer(
trigger=DataprocDeleteClusterTrigger(
Expand Down Expand Up @@ -2480,17 +2499,21 @@ def execute(self, context: Context):
if not self.deferrable:
hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
else:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
cluster = hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
if cluster.status.state != cluster.status.State.RUNNING:
self.defer(
trigger=DataprocClusterTrigger(
cluster_name=self.cluster_name,
project_id=self.project_id,
region=self.region,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval_seconds=self.polling_interval_seconds,
),
method_name="execute_complete",
)
self.log.info("Updated %s cluster.", self.cluster_name)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
Expand Down
146 changes: 142 additions & 4 deletions tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import pytest
from google.api_core.exceptions import AlreadyExists, NotFound
from google.api_core.retry import Retry
from google.cloud.dataproc_v1 import Batch, JobStatus
from google.cloud import dataproc
from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus

from airflow.exceptions import (
AirflowException,
Expand Down Expand Up @@ -579,7 +580,7 @@ def test_build_with_flex_migs(self):
assert CONFIG_WITH_FLEX_MIG == cluster


class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
op = DataprocCreateClusterOperator(
Expand Down Expand Up @@ -883,6 +884,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
cluster = Cluster(
cluster_name="test_cluster",
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
)
mock_hook.return_value.create_cluster.return_value = cluster
mock_hook.return_value.get_cluster.return_value = cluster
operator = DataprocCreateClusterOperator(
task_id=TASK_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
cluster_config=CONFIG,
labels=LABELS,
cluster_name=CLUSTER_NAME,
delete_on_error=True,
metadata=METADATA,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
retry=RETRY,
timeout=TIMEOUT,
deferrable=True,
)

operator.execute(mock.MagicMock())
assert not mock_defer.called

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)

mock_hook.return_value.create_cluster.assert_called_once_with(
region=GCP_REGION,
project_id=GCP_PROJECT,
cluster_config=CONFIG,
request_id=None,
labels=LABELS,
cluster_name=CLUSTER_NAME,
virtual_cluster_config=None,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.assert_not_called()


@pytest.mark.db_test
@pytest.mark.need_serialized_dag
Expand Down Expand Up @@ -1100,6 +1149,47 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
mock_hook.return_value.create_cluster.return_value = None
mock_hook.return_value.get_cluster.side_effect = NotFound("test")
operator = DataprocDeleteClusterOperator(
task_id=TASK_ID,
region=GCP_REGION,
project_id=GCP_PROJECT,
cluster_name=CLUSTER_NAME,
request_id=REQUEST_ID,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
impersonation_chain=IMPERSONATION_CHAIN,
deferrable=True,
)

operator.execute(mock.MagicMock())

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)

mock_hook.return_value.delete_cluster.assert_called_once_with(
project_id=GCP_PROJECT,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
cluster_uuid=None,
request_id=REQUEST_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

mock_hook.return_value.wait_for_operation.assert_not_called()
assert not mock_defer.called


class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
Expand Down Expand Up @@ -1240,8 +1330,8 @@ def test_execute_deferrable(self, mock_trigger_hook, mock_hook):
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
@mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
@mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook):
mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
job_status = mock_hook.return_value.get_job.return_value.status
Expand Down Expand Up @@ -1498,6 +1588,54 @@ def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):
assert isinstance(exc.value.trigger, DataprocClusterTrigger)
assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME

@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
cluster = Cluster(
cluster_name="test_cluster",
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
)
mock_hook.return_value.update_cluster.return_value = cluster
mock_hook.return_value.get_cluster.return_value = cluster
operator = DataprocUpdateClusterOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
cluster=CLUSTER,
update_mask=UPDATE_MASK,
request_id=REQUEST_ID,
graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
project_id=GCP_PROJECT,
gcp_conn_id=GCP_CONN_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
impersonation_chain=IMPERSONATION_CHAIN,
deferrable=True,
)

operator.execute(mock.MagicMock())

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.update_cluster.assert_called_once_with(
project_id=GCP_PROJECT,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
cluster=CLUSTER,
update_mask=UPDATE_MASK,
request_id=REQUEST_ID,
graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
mock_hook.return_value.wait_for_operation.assert_not_called()
assert not mock_defer.called


@pytest.mark.db_test
@pytest.mark.need_serialized_dag
Expand Down

0 comments on commit e07a42e

Please sign in to comment.