Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
molcay committed Jan 23, 2024
1 parent 66143ef commit 7d07272
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,56 +1151,64 @@ def __init__(
self.metadata = metadata
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self._hook: DataprocHook | None = None

def _get_cluster(self, hook: DataprocHook) -> Cluster:
@property
def hook(self) -> DataprocHook:
    """Return the ``DataprocHook`` used by this operator, creating it on first access.

    The hook is built from ``gcp_conn_id`` and ``impersonation_chain`` and stored on
    ``self._hook``, so every later access reuses the same instance.

    :return: Instance of ``DataprocHook`` bound to this operator's connection settings.
    """
    if self._hook is None:
        self._hook = DataprocHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
    return self._hook

def _get_project_id(self) -> str:
    """Return the project ID to use for Dataproc calls.

    The operator's own ``project_id`` wins when it is set (truthy); otherwise the
    value falls back to ``project_id`` of the hook.

    :return: The project ID taken from the operator or, failing that, from the hook.
    """
    return self.project_id or self.hook.project_id

def _get_cluster(self) -> Cluster:
    """Retrieve the cluster information.

    The lookup goes through ``self.hook`` and uses the project ID resolved by
    ``_get_project_id`` together with the operator's ``region``, ``cluster_name``,
    ``retry``, ``timeout`` and ``metadata``.

    :return: Instance of ``google.cloud.dataproc_v1.Cluster`` class
    """
    return self.hook.get_cluster(
        project_id=self._get_project_id(),
        region=self.region,
        cluster_name=self.cluster_name,
        retry=self.retry,
        timeout=self.timeout,
        metadata=self.metadata,
    )

def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
    """Implement this method in child class to return whether the cluster is in desired state or not.

    If the cluster is in the desired state you can return a log message content as a second value
    for the return tuple.

    :param cluster: Required. Instance of ``google.cloud.dataproc_v1.Cluster`` describing the
        cluster whose state is checked.
    :return: Tuple of (Boolean, Optional[str]). The first value of the tuple is whether the cluster is
        in desired state or not. The second value of the tuple is used if you want to log something
        when the cluster is in desired state already.
    :raises NotImplementedError: Always, until a child class overrides this method.
    """
    raise NotImplementedError

def _get_operation(self) -> operation.Operation:
    """Implement this method in child class to call the related hook method and return its result.

    :return: The ``google.api_core.operation.Operation`` returned by the hook call.
    :raises NotImplementedError: Always, until a child class overrides this method.
    """
    raise NotImplementedError

def execute(self, context: Context) -> dict | None:
    """Bring the cluster into the desired state unless it is already there.

    The current cluster is fetched and handed to ``_check_desired_cluster_state``. When that
    reports the desired state, its message is logged and nothing else happens. Otherwise the
    operation from ``_get_operation`` is started and awaited through the hook.

    :param context: Airflow task context; not read by this method.
    :return: ``None`` when the cluster was already in the desired state, otherwise the result of
        the awaited operation converted with ``Cluster.to_dict``.
    """
    cluster: Cluster = self._get_cluster()
    is_already_desired_state, log_str = self._check_desired_cluster_state(cluster)
    if is_already_desired_state:
        # Nothing to do: report why and skip starting an operation.
        self.log.info(log_str)
        return None

    op: operation.Operation = self._get_operation()
    result = self.hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)
    return Cluster.to_dict(result)


Expand All @@ -1216,7 +1224,7 @@ def execute(self, context: Context) -> dict | None:
context=context,
operator=self,
cluster_id=self.cluster_name,
project_id=self.project_id,
project_id=self._get_project_id(),
region=self.region,
)
self.log.info("Cluster started")
Expand All @@ -1227,10 +1235,10 @@ def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | No
return True, f'The cluster "{self.cluster_name}" already running!'
return False, None

def _get_operation(self, hook: DataprocHook) -> operation.Operation:
return hook.start_cluster(
def _get_operation(self) -> operation.Operation:
return self.hook.start_cluster(
region=self.region,
project_id=self.project_id,
project_id=self._get_project_id(),
cluster_name=self.cluster_name,
cluster_uuid=self.cluster_uuid,
retry=self.retry,
Expand All @@ -1253,10 +1261,10 @@ def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | No
return True, f'The cluster "{self.cluster_name}" already stopped!'
return False, None

def _get_operation(self, hook: DataprocHook) -> operation.Operation:
return hook.stop_cluster(
def _get_operation(self) -> operation.Operation:
return self.hook.stop_cluster(
region=self.region,
project_id=self.project_id,
project_id=self._get_project_id(),
cluster_name=self.cluster_name,
cluster_uuid=self.cluster_uuid,
retry=self.retry,
Expand Down

0 comments on commit 7d07272

Please sign in to comment.