feat(providers/google): deprecate DataprocUpdateClusterOperatorAsync
Lee-W committed Jan 23, 2024
1 parent 919b086 commit 1e0189c
Showing 2 changed files with 25 additions and 177 deletions.
107 changes: 13 additions & 94 deletions astronomer/providers/google/cloud/operators/dataproc.py
@@ -1,14 +1,12 @@
"""This module contains Google Dataproc operators."""
from __future__ import annotations

import time
import warnings
from typing import Any

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.providers.google.cloud.links.dataproc import (
DATAPROC_CLUSTER_LINK,
DATAPROC_JOB_LOG_LINK,
DataprocLink,
)
@@ -21,7 +19,6 @@
from google.cloud.dataproc_v1 import JobStatus

from astronomer.providers.google.cloud.triggers.dataproc import (
DataprocCreateClusterTrigger,
DataProcSubmitTrigger,
)
from astronomer.providers.utils.typing_compat import Context
@@ -172,42 +169,9 @@ def execute_complete( # type: ignore[override]

class DataprocUpdateClusterOperatorAsync(DataprocUpdateClusterOperator):
"""
Updates an existing cluster in a Google cloud platform project.
:param region: Required. The Cloud Dataproc region in which to handle the request.
:param project_id: Optional. The ID of the Google Cloud project the cluster belongs to.
:param cluster_name: Required. The cluster name.
:param cluster: Required. The changes to the cluster.
If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.cloud.dataproc_v1.types.Cluster`
:param update_mask: Required. Specifies the path, relative to ``Cluster``, of the field to update. For
example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter would be
specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body would specify the
new value. If a dict is provided, it must be of the same form as the protobuf message
:class:`~google.protobuf.field_mask_pb2.FieldMask`
:param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
decommissioning allows removing nodes from the cluster without interrupting jobs in progress. Timeout
specifies how long to wait for jobs in progress to finish before forcefully removing nodes (and
potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the maximum
allowed timeout is 1 day.
:param request_id: Optional. A unique id used to identify the request. If the server receives two
``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
first ``google.longrunning.Operation`` created and stored in the backend is returned.
:param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
retried.
:param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
``retry`` is specified, the timeout applies to each individual attempt.
:param metadata: Additional metadata that is provided to the method.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval: Time in seconds to sleep between checks of cluster status
This class is deprecated.
Please use :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperator`
and set the ``deferrable`` param to ``True`` instead.
"""

def __init__(
@@ -216,61 +180,16 @@ def __init__(
polling_interval: float = 5.0,
**kwargs: Any,
):
super().__init__(**kwargs)
warnings.warn(
(
"This module is deprecated and will be removed in 2.0.0."
"Please use `airflow.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperator`"
"and set `deferrable` param to `True` instead."
),
DeprecationWarning,
stacklevel=2,
)
super().__init__(deferrable=True, **kwargs)
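# With deferrable=True, the upstream DataprocUpdateClusterOperator submits the
# update request and then defers to Airflow's triggerer, so this subclass no
# longer needs its own execute()/defer()/execute_complete() logic; the
# polling_interval argument is retained only so existing call sites keep working.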
self.polling_interval = polling_interval
if self.timeout is None:
self.timeout: float = 24 * 60 * 60

def execute(self, context: Context) -> None:
"""Call update cluster API and defer to wait for cluster update to complete"""
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required by extra links no matter what the cluster status will be
DataprocLink.persist(
context=context, task_instance=self, url=DATAPROC_CLUSTER_LINK, resource=self.cluster_name
)
self.log.info("Updating %s cluster.", self.cluster_name)
hook.update_cluster(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
cluster=self.cluster,
update_mask=self.update_mask,
graceful_decommission_timeout=self.graceful_decommission_timeout,
request_id=self.request_id,
retry=self.retry,
metadata=self.metadata,
)
cluster = hook.get_cluster(
project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
)
if cluster.status.state == cluster.status.State.RUNNING:
self.log.info("Updated %s cluster.", self.cluster_name)
else:
end_time: float = time.time() + self.timeout

self.defer(
trigger=DataprocCreateClusterTrigger(
project_id=self.project_id,
region=self.region,
cluster_name=self.cluster_name,
end_time=end_time,
metadata=self.metadata,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
polling_interval=self.polling_interval,
),
method_name="execute_complete",
)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event and event["status"] == "success":
self.log.info("Updated %s cluster.", event["cluster_name"])
return
if event and event["status"] == "error":
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")
95 changes: 12 additions & 83 deletions tests/google/cloud/operators/test_dataproc.py
@@ -5,9 +5,9 @@
from airflow.providers.google.cloud.operators.dataproc import (
DataprocCreateClusterOperator,
DataprocDeleteClusterOperator,
DataprocUpdateClusterOperator,
)
from google.cloud import dataproc
from google.cloud.dataproc_v1 import Cluster, JobStatus
from google.cloud.dataproc_v1 import JobStatus

from astronomer.providers.google.cloud.operators.dataproc import (
DataprocCreateClusterOperatorAsync,
@@ -16,7 +16,6 @@
DataprocUpdateClusterOperatorAsync,
)
from astronomer.providers.google.cloud.triggers.dataproc import (
DataprocCreateClusterTrigger,
DataProcSubmitTrigger,
)
from tests.utils.airflow_util import create_context
@@ -136,85 +135,15 @@ def test_dataproc_operator_execute_success_async(self, mock_submit_job):


class TestDataprocUpdateClusterOperatorAsync:
OPERATOR = DataprocUpdateClusterOperatorAsync(
task_id="task-id",
cluster_name="test_cluster",
region=TEST_REGION,
project_id=TEST_PROJECT_ID,
cluster={},
graceful_decommission_timeout=30,
update_mask={},
)

@mock.patch(
"astronomer.providers.google.cloud.operators.dataproc.DataprocUpdateClusterOperatorAsync.defer"
)
@mock.patch("airflow.providers.google.cloud.links.dataproc.DataprocLink.persist")
@mock.patch(f"{MODULE}.get_cluster")
@mock.patch(f"{MODULE}.update_cluster")
def test_dataproc_operator_update_cluster_execute_async_finish_before_defer(
self, mock_update_cluster, mock_get_cluster, mock_persist, mock_defer, context
):
mock_persist.return_value = {}
cluster = Cluster(
cluster_name="test_cluster",
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
)
mock_update_cluster.return_value = cluster
mock_get_cluster.return_value = cluster
DataprocCreateClusterOperatorAsync(
task_id="task-id", cluster_name="test_cluster", region=TEST_REGION, project_id=TEST_PROJECT_ID
)
self.OPERATOR.execute(context)
assert not mock_defer.called

@mock.patch("airflow.providers.google.cloud.links.dataproc.DataprocLink.persist")
@mock.patch(f"{MODULE}.get_cluster")
@mock.patch(f"{MODULE}.update_cluster")
def test_dataproc_operator_update_cluster_execute_async(
self, mock_update_cluster, mock_get_cluster, mock_persist, context
):
"""
Asserts that a task is deferred and a DataprocCreateClusterTrigger will be fired
when the DataprocUpdateClusterOperatorAsync is executed.
"""
mock_persist.return_value = {}
cluster = Cluster(
cluster_name="test_cluster",
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING),
)
mock_update_cluster.return_value = cluster
mock_get_cluster.return_value = cluster

with pytest.raises(TaskDeferred) as exc:
self.OPERATOR.execute(context)
assert isinstance(
exc.value.trigger, DataprocCreateClusterTrigger
), "Trigger is not a DataprocCreateClusterTrigger"

def test_dataproc_operator_update_cluster_execute_complete_success(self, context):
    """assert that execute_complete return cluster detail when task succeed"""
    cluster = Cluster(
        cluster_name="test_cluster",
        status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.CREATING),
    )

    assert (
        self.OPERATOR.execute_complete(
            context=context, event={"status": "success", "data": cluster, "cluster_name": "test_cluster"}
        )
        is None
    )

@pytest.mark.parametrize(
    "event",
    [
        {"status": "error", "message": ""},
        None,
    ],
)
def test_dataproc_operator_update_cluster_execute_complete_fail(self, event, context):
    """assert that execute_complete raise exception when task fail"""

    with pytest.raises(AirflowException):
        self.OPERATOR.execute_complete(context=context, event=event)

def test_init(self):
    task = DataprocUpdateClusterOperatorAsync(
        task_id="task-id",
        cluster_name="test_cluster",
        region=TEST_REGION,
        project_id=TEST_PROJECT_ID,
        cluster={},
        graceful_decommission_timeout=30,
        update_mask={},
    )
    assert isinstance(task, DataprocUpdateClusterOperator)
    assert task.deferrable is True
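
A natural companion to test_init, though not part of this commit: the new
deprecation path can be exercised directly with pytest.warns. A minimal sketch,
reusing the suite's TEST_REGION and TEST_PROJECT_ID constants:

def test_init_warns_deprecation():
    # The constructor should emit the DeprecationWarning added in this commit.
    with pytest.warns(DeprecationWarning, match="deprecated"):
        DataprocUpdateClusterOperatorAsync(
            task_id="task-id",
            cluster_name="test_cluster",
            region=TEST_REGION,
            project_id=TEST_PROJECT_ID,
            cluster={},
            graceful_decommission_timeout=30,
            update_mask={},
        )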
