Skip to content

Commit

Permalink
Introduce operation helper class and refactor
Browse files Browse the repository at this point in the history
- Refactor Google Cloud hooks that used generic 'wait_for_operation'
  methods.
  • Loading branch information
Oleg Kachur committed Dec 30, 2024
1 parent 088242a commit 7fac8d8
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 125 deletions.
12 changes: 2 additions & 10 deletions providers/src/airflow/providers/google/cloud/hooks/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
PredictResponse,
)

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -58,7 +58,7 @@
from google.protobuf.field_mask_pb2 import FieldMask


class CloudAutoMLHook(GoogleBaseHook):
class CloudAutoMLHook(GoogleBaseHook, OperationHelper):
"""
Google Cloud AutoML hook.
Expand Down Expand Up @@ -94,14 +94,6 @@ def get_conn(self) -> AutoMlClient:
self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
return self._client

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@cached_property
def prediction_client(self) -> PredictionServiceClient:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -42,7 +43,7 @@
TIME_TO_SLEEP_IN_SECONDS = 5


class CloudBuildHook(GoogleBaseHook):
class CloudBuildHook(GoogleBaseHook, OperationHelper):
"""
Hook for the Google Cloud Build Service.
Expand Down Expand Up @@ -80,14 +81,6 @@ def _get_build_id_from_operation(self, operation: Operation) -> str:
except Exception:
raise AirflowException("Could not retrieve Build ID from Operation.")

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def get_conn(self, location: str = "global") -> CloudBuildClient:
"""
Retrieve the connection to Google Cloud Build.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@
)
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

class CloudComposerHook(GoogleBaseHook):

class CloudComposerHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Composer APIs."""

client_options = ClientOptions(api_endpoint="composer.googleapis.com:443")
Expand All @@ -74,14 +76,6 @@ def get_image_versions_client(self) -> ImageVersionsClient:
client_options=self.client_options,
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def get_operation(self, operation_name):
return self.get_environment_client().transport.operations_client.get_operation(name=operation_name)

Expand Down
12 changes: 2 additions & 10 deletions providers/src/airflow/providers/google/cloud/hooks/dataplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
GoogleBaseAsyncHook,
GoogleBaseHook,
)
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.api_core.retry_async import AsyncRetry
from googleapiclient.discovery import Resource
Expand All @@ -60,7 +60,7 @@ class AirflowDataQualityScanResultTimeoutException(AirflowException):
"""Raised when no result found after specified amount of seconds."""


class DataplexHook(GoogleBaseHook):
class DataplexHook(GoogleBaseHook, OperationHelper):
"""
Hook for Google Dataplex.
Expand Down Expand Up @@ -110,14 +110,6 @@ def get_dataplex_data_scan_client(self) -> DataScanServiceClient:
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def wait_for_operation(self, timeout: float | None, operation: Operation):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@GoogleBaseHook.fallback_to_default_project_id
def create_task(
self,
Expand Down
13 changes: 2 additions & 11 deletions providers/src/airflow/providers/google/cloud/hooks/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -51,7 +52,6 @@
TransliterationConfig,
automl_translation,
)
from proto import Message


class WaitOperationNotDoneYetError(Exception):
Expand Down Expand Up @@ -153,7 +153,7 @@ def translate(
)


class TranslateHook(GoogleBaseHook):
class TranslateHook(GoogleBaseHook, OperationHelper):
"""
Hook for Google Cloud translation (Advanced) using client version V3.
Expand Down Expand Up @@ -219,15 +219,6 @@ def wait_for_operation_done(
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message:
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except GoogleAPICallError:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_object_id(obj: dict) -> str:
"""Return unique id of the object."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -47,7 +48,7 @@
from google.cloud.aiplatform_v1.types import TrainingPipeline


class AutoMLHook(GoogleBaseHook):
class AutoMLHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Auto ML APIs."""

def __init__(
Expand Down Expand Up @@ -253,14 +254,6 @@ def extract_training_id(resource_name: str) -> str:
"""Return unique id of the Training pipeline."""
return resource_name.rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_auto_ml_job(self) -> None:
"""Cancel Auto ML Job for training pipeline."""
if self._job:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import AsyncRetry, Retry
from google.cloud.aiplatform_v1.services.job_service.pagers import ListBatchPredictionJobsPager


class BatchPredictionJobHook(GoogleBaseHook):
class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Batch Prediction Job APIs."""

def __init__(
Expand All @@ -65,14 +66,6 @@ def get_job_service_client(self, region: str | None = None) -> JobServiceClient:
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_batch_prediction_job_id(obj: dict) -> str:
"""Return unique id of the batch_prediction_job."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -59,7 +60,7 @@
from google.cloud.aiplatform_v1.types import CustomJob, PipelineJob, TrainingPipeline


class CustomJobHook(GoogleBaseHook):
class CustomJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Custom Job APIs."""

def __init__(
Expand Down Expand Up @@ -277,14 +278,6 @@ def extract_custom_job_id_from_training_pipeline(training_pipeline: dict[str, An
"""Return a unique Custom Job id from a serialized TrainingPipeline proto."""
return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_job(self) -> None:
"""Cancel Job for training pipeline."""
if self._job:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1 import DatasetServiceClient

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -42,7 +42,7 @@
from google.protobuf.field_mask_pb2 import FieldMask


class DatasetHook(GoogleBaseHook):
class DatasetHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Dataset APIs."""

def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient:
Expand All @@ -56,14 +56,6 @@ def get_dataset_service_client(self, region: str | None = None) -> DatasetServic
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_dataset_id(obj: dict) -> str:
"""Return unique id of the dataset."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1 import EndpointServiceClient

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -37,7 +37,7 @@
from google.protobuf.field_mask_pb2 import FieldMask


class EndpointServiceHook(GoogleBaseHook):
class EndpointServiceHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Endpoint Service APIs."""

def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient:
Expand All @@ -51,14 +51,6 @@ def get_endpoint_service_client(self, region: str | None = None) -> EndpointServ
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_endpoint_id(obj: dict) -> str:
"""Return unique id of the endpoint."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import AsyncRetry, Retry
from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager


class HyperparameterTuningJobHook(GoogleBaseHook):
class HyperparameterTuningJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""

def __init__(
Expand Down Expand Up @@ -134,14 +135,6 @@ def extract_hyperparameter_tuning_job_id(obj: dict) -> str:
"""Return unique id of the hyperparameter_tuning_job."""
return obj["name"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_hyperparameter_tuning_job(self) -> None:
"""Cancel HyperparameterTuningJob."""
if self._hyperparameter_tuning_job:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -40,7 +41,7 @@
from google.cloud.aiplatform_v1.types import Model, model_service


class ModelServiceHook(GoogleBaseHook):
class ModelServiceHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Model Service APIs."""

def get_model_service_client(self, region: str | None = None) -> ModelServiceClient:
Expand All @@ -59,14 +60,6 @@ def extract_model_id(obj: dict) -> str:
"""Return unique id of the model."""
return obj["model"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@GoogleBaseHook.fallback_to_default_project_id
def delete_model(
self,
Expand Down
Loading

0 comments on commit 7fac8d8

Please sign in to comment.