Skip to content

Commit

Permalink
Create DataprocStartClusterOperator and DataprocStopClusterOperator
Browse files Browse the repository at this point in the history
Add base class for start-stop cluster operators to fix failed test

Fix typing
  • Loading branch information
molcay committed Jan 23, 2024
1 parent fbd21ed commit 5c6d69d
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 0 deletions.
88 changes: 88 additions & 0 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,94 @@ def update_cluster(
)
return operation

@GoogleBaseHook.fallback_to_default_project_id
def start_cluster(
    self,
    region: str,
    project_id: str,
    cluster_name: str,
    cluster_uuid: str | None = None,
    request_id: str | None = None,
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> Operation:
    """Start a cluster in a project.

    :param region: Cloud Dataproc region to handle the request.
    :param project_id: Google Cloud project ID that the cluster belongs to.
    :param cluster_name: The cluster name.
    :param cluster_uuid: The cluster UUID.
    :param request_id: A unique id used to identify the request. If the
        server receives two *StartClusterRequest* requests with the same
        ID, the second request will be ignored, and an operation created
        for the first one and stored in the backend is returned.
    :param retry: A retry object used to retry requests. If *None*, requests
        will not be retried.
    :param timeout: The amount of time, in seconds, to wait for the request
        to complete. If *retry* is specified, the timeout applies to each
        individual attempt.
    :param metadata: Additional metadata that is provided to the method.
    :return: An instance of ``google.api_core.operation.Operation``
    """
    # The request payload is passed as a plain dict; unset optional fields
    # are forwarded as ``None`` exactly as received.
    start_request = {
        "project_id": project_id,
        "region": region,
        "cluster_name": cluster_name,
        "cluster_uuid": cluster_uuid,
        "request_id": request_id,
    }
    cluster_client = self.get_cluster_client(region=region)
    return cluster_client.start_cluster(
        request=start_request,
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

@GoogleBaseHook.fallback_to_default_project_id
def stop_cluster(
    self,
    region: str,
    project_id: str,
    cluster_name: str,
    cluster_uuid: str | None = None,
    request_id: str | None = None,
    retry: Retry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> Operation:
    """Stop a cluster in a project.

    :param region: Cloud Dataproc region to handle the request.
    :param project_id: Google Cloud project ID that the cluster belongs to.
    :param cluster_name: The cluster name.
    :param cluster_uuid: The cluster UUID.
    :param request_id: A unique id used to identify the request. If the
        server receives two *StopClusterRequest* requests with the same
        ID, the second request will be ignored, and an operation created
        for the first one and stored in the backend is returned.
    :param retry: A retry object used to retry requests. If *None*, requests
        will not be retried.
    :param timeout: The amount of time, in seconds, to wait for the request
        to complete. If *retry* is specified, the timeout applies to each
        individual attempt.
    :param metadata: Additional metadata that is provided to the method.
    :return: An instance of ``google.api_core.operation.Operation``
    """
    client = self.get_cluster_client(region=region)
    return client.stop_cluster(
        request={
            "project_id": project_id,
            "region": region,
            "cluster_name": cluster_name,
            "cluster_uuid": cluster_uuid,
            "request_id": request_id,
        },
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )

@GoogleBaseHook.fallback_to_default_project_id
def create_workflow_template(
self,
Expand Down
189 changes: 189 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,17 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
cluster = self._get_cluster(hook)
return cluster

def _start_cluster(self, hook: DataprocHook):
    """Request a start of this operator's cluster and wait on the resulting operation.

    :param hook: ``DataprocHook`` used both to issue the start request and to
        wait for the returned long-running operation.
    :return: Whatever ``hook.wait_for_operation`` returns for the start operation.
    """
    start_operation: operation.Operation = hook.start_cluster(
        region=self.region,
        project_id=self.project_id,
        cluster_name=self.cluster_name,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )
    # The same timeout/retry settings are reused for the wait as for the request itself.
    return hook.wait_for_operation(
        operation=start_operation,
        timeout=self.timeout,
        result_retry=self.retry,
    )

def execute(self, context: Context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
Expand Down Expand Up @@ -795,6 +806,9 @@ def execute(self, context: Context) -> dict:
# Create new cluster
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)
elif cluster.status.state == cluster.status.State.STOPPED:
# if the cluster exists and already stopped, then start the cluster
self._start_cluster(hook)

return Cluster.to_dict(cluster)

Expand Down Expand Up @@ -1076,6 +1090,181 @@ def _delete_cluster(self, hook: DataprocHook):
)


class _DataprocStartStopClusterBaseOperator(GoogleCloudBaseOperator):
    """Base class to start or stop a cluster in a project.

    Subclasses provide ``_check_desired_cluster_state`` and ``_get_operation``;
    ``execute`` uses them to skip the request when the cluster is already in the
    desired state, or otherwise to issue the request and wait for it.

    :param cluster_name: Required. Name of the cluster to start or stop.
    :param region: Required. The specified region where the dataproc cluster is created.
    :param project_id: Optional. The ID of the Google Cloud project the cluster belongs to.
    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the RPC should fail
        if cluster with specified UUID does not exist.
    :param request_id: Optional. A unique id used to identify the request. If the server receives two
        start (or stop) cluster requests with the same id, then the second request will be ignored and the
        first ``google.longrunning.Operation`` created and stored in the backend is returned.
    :param retry: A retry object used to retry requests. If ``None`` is specified, requests will not be
        retried.
    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
        ``retry`` is specified, the timeout applies to each individual attempt.
    :param metadata: Additional metadata that is provided to the method.
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    template_fields = (
        "cluster_name",
        "region",
        "project_id",
        "request_id",
        "impersonation_chain",
    )

    def __init__(
        self,
        *,
        cluster_name: str,
        region: str,
        project_id: str | None = None,
        cluster_uuid: str | None = None,
        request_id: str | None = None,
        retry: AsyncRetry | _MethodDefault = DEFAULT,
        timeout: float = 1 * 60 * 60,
        metadata: Sequence[tuple[str, str]] = (),
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id
        self.region = region
        self.cluster_name = cluster_name
        self.cluster_uuid = cluster_uuid
        self.request_id = request_id
        self.retry = retry
        self.timeout = timeout
        self.metadata = metadata
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def _get_cluster(self, hook: DataprocHook) -> Cluster:
        """Retrieve the cluster information.

        :param hook: Required. Instance of ``airflow.providers.google.cloud.hooks.dataproc.DataprocHook``
            class to interact with Dataproc API
        :return: Instance of ``google.cloud.dataproc_v1.Cluster`` class
        """
        return hook.get_cluster(
            project_id=self.project_id,
            region=self.region,
            cluster_name=self.cluster_name,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )

    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
        """Implement this method in child class to return whether the cluster is in desired state or not.

        If the cluster is in desired state you can return a log message content as a second value
        for the return tuple.

        :param cluster: Required. Instance of ``google.cloud.dataproc_v1.Cluster``
            class describing the current cluster, as returned by ``_get_cluster``
        :return: Tuple of (Boolean, Optional[str]) The first value of the tuple is whether the cluster is
            in desired state or not. The second value of the tuple will be used if you want to log
            something when the cluster is in desired state already.
        """
        raise NotImplementedError

    def _get_operation(self, hook: DataprocHook) -> operation.Operation:
        """Implement this method in child class to call the related hook method and return its result.

        :param hook: Required. Instance of ``airflow.providers.google.cloud.hooks.dataproc.DataprocHook``
            class to interact with Dataproc API
        :return: ``google.api_core.operation.Operation`` returned by the hook call; ``execute`` waits on it
        """
        raise NotImplementedError

    def execute(self, context: Context) -> dict | None:
        """Bring the cluster to the desired state unless it is already there.

        :param context: Airflow task context (not read by this method).
        :return: ``None`` when the cluster was already in the desired state, otherwise the
            result of the finished operation converted with ``Cluster.to_dict``.
        """
        hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
        cluster: Cluster = self._get_cluster(hook)
        is_already_desired_state, log_str = self._check_desired_cluster_state(cluster)
        if is_already_desired_state:
            # The message is optional by contract; avoid logging a literal "None".
            if log_str:
                self.log.info(log_str)
            return None

        op: operation.Operation = self._get_operation(hook)
        result = hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)
        return Cluster.to_dict(result)


class DataprocStartClusterOperator(_DataprocStartStopClusterBaseOperator):
    """Start a cluster in a project."""

    operator_extra_links = (DataprocClusterLink(),)

    def execute(self, context: Context) -> dict | None:
        """Start the cluster via the base-class flow and persist the cluster link.

        :param context: Airflow task context, forwarded to ``DataprocClusterLink.persist``.
        :return: The value returned by the base class ``execute``.
        """
        self.log.info("Starting the cluster: %s", self.cluster_name)
        cluster = super().execute(context)
        DataprocClusterLink.persist(
            context=context,
            operator=self,
            cluster_id=self.cluster_name,
            project_id=self.project_id,
            region=self.region,
        )
        self.log.info("Cluster started")
        return cluster

    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
        """Report the cluster as already in the desired state when it is ``RUNNING``.

        :param cluster: Current cluster description.
        :return: ``(True, message)`` if the cluster is running, otherwise ``(False, None)``.
        """
        if cluster.status.state == cluster.status.State.RUNNING:
            return True, f'The cluster "{self.cluster_name}" already running!'
        return False, None

    def _get_operation(self, hook: DataprocHook) -> operation.Operation:
        """Issue the start request through the hook and return the resulting operation.

        :param hook: ``DataprocHook`` used to call ``start_cluster``.
        :return: The operation returned by ``hook.start_cluster``.
        """
        return hook.start_cluster(
            region=self.region,
            project_id=self.project_id,
            cluster_name=self.cluster_name,
            cluster_uuid=self.cluster_uuid,
            # Forward the user-supplied request id; it was previously accepted
            # by the operator but never sent to the API.
            request_id=self.request_id,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )


class DataprocStopClusterOperator(_DataprocStartStopClusterBaseOperator):
    """Stop a cluster in a project."""

    def execute(self, context: Context) -> dict | None:
        """Stop the cluster via the base-class flow.

        :param context: Airflow task context, forwarded to the base class ``execute``.
        :return: The value returned by the base class ``execute``.
        """
        self.log.info("Stopping the cluster: %s", self.cluster_name)
        cluster = super().execute(context)
        self.log.info("Cluster stopped")
        return cluster

    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
        """Report the cluster as already in the desired state when it is ``STOPPED`` or ``STOPPING``.

        :param cluster: Current cluster description.
        :return: ``(True, message)`` if the cluster is stopped or stopping, otherwise ``(False, None)``.
        """
        if cluster.status.state in [cluster.status.State.STOPPED, cluster.status.State.STOPPING]:
            return True, f'The cluster "{self.cluster_name}" already stopped!'
        return False, None

    def _get_operation(self, hook: DataprocHook) -> operation.Operation:
        """Issue the stop request through the hook and return the resulting operation.

        :param hook: ``DataprocHook`` used to call ``stop_cluster``.
        :return: The operation returned by ``hook.stop_cluster``.
        """
        return hook.stop_cluster(
            region=self.region,
            project_id=self.project_id,
            cluster_name=self.cluster_name,
            cluster_uuid=self.cluster_uuid,
            # Forward the user-supplied request id; it was previously accepted
            # by the operator but never sent to the API.
            request_id=self.request_id,
            retry=self.retry,
            timeout=self.timeout,
            metadata=self.metadata,
        )


class DataprocJobBaseOperator(GoogleCloudBaseOperator):
"""Base class for operators that launch job on DataProc.
Expand Down
24 changes: 24 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,30 @@ You can use deferrable mode for this action in order to run the operator asynchr
:start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
:end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]

Starting a cluster
---------------------------

To start a cluster you can use the
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_start_cluster_operator]
:end-before: [END how_to_cloud_dataproc_start_cluster_operator]

Stopping a cluster
---------------------------

To stop a cluster you can use the
:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStopClusterOperator`:

.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_dataproc_stop_cluster_operator]
:end-before: [END how_to_cloud_dataproc_stop_cluster_operator]

Deleting a cluster
------------------

Expand Down
1 change: 1 addition & 0 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator",
"airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator",
"airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator",
}
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/google/cloud/hooks/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,48 @@ def test_update_cluster_missing_region(self, mock_client):
update_mask="update-mask",
)

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
def test_start_cluster(self, mock_client):
    """Check that ``start_cluster`` builds the expected request and client call."""
    # Optional fields that were not supplied are expected to be forwarded as None.
    expected_request = {
        "project_id": GCP_PROJECT,
        "region": GCP_LOCATION,
        "cluster_name": CLUSTER_NAME,
        "cluster_uuid": None,
        "request_id": None,
    }

    self.hook.start_cluster(
        region=GCP_LOCATION,
        project_id=GCP_PROJECT,
        cluster_name=CLUSTER_NAME,
    )

    mock_client.assert_called_once_with(region=GCP_LOCATION)
    cluster_client = mock_client.return_value
    cluster_client.start_cluster.assert_called_once_with(
        request=expected_request,
        metadata=(),
        retry=DEFAULT,
        timeout=None,
    )

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
def test_stop_cluster(self, mock_client):
    """Check that ``stop_cluster`` builds the expected request and client call."""
    # Optional fields that were not supplied are expected to be forwarded as None.
    expected_request = {
        "project_id": GCP_PROJECT,
        "region": GCP_LOCATION,
        "cluster_name": CLUSTER_NAME,
        "cluster_uuid": None,
        "request_id": None,
    }

    self.hook.stop_cluster(
        region=GCP_LOCATION,
        project_id=GCP_PROJECT,
        cluster_name=CLUSTER_NAME,
    )

    mock_client.assert_called_once_with(region=GCP_LOCATION)
    cluster_client = mock_client.return_value
    cluster_client.stop_cluster.assert_called_once_with(
        request=expected_request,
        metadata=(),
        retry=DEFAULT,
        timeout=None,
    )

@mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
def test_create_workflow_template(self, mock_client):
template = {"test": "test"}
Expand Down
Loading

0 comments on commit 5c6d69d

Please sign in to comment.