Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce operation helper class and refactor #147

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions providers/src/airflow/providers/google/cloud/hooks/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@
PredictResponse,
)

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -65,7 +66,7 @@
"airflow.providers.google.cloud.hooks.translate.TranslateHook",
category=AirflowProviderDeprecationWarning,
)
class CloudAutoMLHook(GoogleBaseHook):
class CloudAutoMLHook(GoogleBaseHook, OperationHelper):
"""
Google Cloud AutoML hook.

Expand Down Expand Up @@ -101,14 +102,6 @@ def get_conn(self) -> AutoMlClient:
self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
return self._client

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@cached_property
def prediction_client(self) -> PredictionServiceClient:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -42,7 +43,7 @@
TIME_TO_SLEEP_IN_SECONDS = 5


class CloudBuildHook(GoogleBaseHook):
class CloudBuildHook(GoogleBaseHook, OperationHelper):
"""
Hook for the Google Cloud Build Service.

Expand Down Expand Up @@ -80,14 +81,6 @@ def _get_build_id_from_operation(self, operation: Operation) -> str:
except Exception:
raise AirflowException("Could not retrieve Build ID from Operation.")

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def get_conn(self, location: str = "global") -> CloudBuildClient:
"""
Retrieve the connection to Google Cloud Build.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@
)
from google.protobuf.field_mask_pb2 import FieldMask

from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

class CloudComposerHook(GoogleBaseHook):

class CloudComposerHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Composer APIs."""

client_options = ClientOptions(api_endpoint="composer.googleapis.com:443")
Expand All @@ -74,14 +76,6 @@ def get_image_versions_client(self) -> ImageVersionsClient:
client_options=self.client_options,
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def get_operation(self, operation_name):
return self.get_environment_client().transport.operations_client.get_operation(name=operation_name)

Expand Down
12 changes: 2 additions & 10 deletions providers/src/airflow/providers/google/cloud/hooks/dataplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@
GoogleBaseAsyncHook,
GoogleBaseHook,
)
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import Retry
from google.api_core.retry_async import AsyncRetry
from googleapiclient.discovery import Resource
Expand All @@ -60,7 +60,7 @@ class AirflowDataQualityScanResultTimeoutException(AirflowException):
"""Raised when no result found after specified amount of seconds."""


class DataplexHook(GoogleBaseHook):
class DataplexHook(GoogleBaseHook, OperationHelper):
"""
Hook for Google Dataplex.

Expand Down Expand Up @@ -110,14 +110,6 @@ def get_dataplex_data_scan_client(self) -> DataScanServiceClient:
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def wait_for_operation(self, timeout: float | None, operation: Operation):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@GoogleBaseHook.fallback_to_default_project_id
def create_task(
self,
Expand Down
13 changes: 2 additions & 11 deletions providers/src/airflow/providers/google/cloud/hooks/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -53,7 +54,6 @@
automl_translation,
)
from google.cloud.translate_v3.types.translation_service import Glossary
from proto import Message


class WaitOperationNotDoneYetError(Exception):
Expand Down Expand Up @@ -155,7 +155,7 @@ def translate(
)


class TranslateHook(GoogleBaseHook):
class TranslateHook(GoogleBaseHook, OperationHelper):
"""
Hook for Google Cloud translation (Advanced) using client version V3.

Expand Down Expand Up @@ -221,15 +221,6 @@ def wait_for_operation_done(
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message:
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except GoogleAPICallError:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_object_id(obj: dict) -> str:
"""Return unique id of the object."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -47,7 +48,7 @@
from google.cloud.aiplatform_v1.types import TrainingPipeline


class AutoMLHook(GoogleBaseHook):
class AutoMLHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Auto ML APIs."""

def __init__(
Expand Down Expand Up @@ -253,14 +254,6 @@ def extract_training_id(resource_name: str) -> str:
"""Return unique id of the Training pipeline."""
return resource_name.rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_auto_ml_job(self) -> None:
"""Cancel Auto ML Job for training pipeline."""
if self._job:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import AsyncRetry, Retry
from google.cloud.aiplatform_v1.services.job_service.pagers import ListBatchPredictionJobsPager


class BatchPredictionJobHook(GoogleBaseHook):
class BatchPredictionJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Batch Prediction Job APIs."""

def __init__(
Expand All @@ -65,14 +66,6 @@ def get_job_service_client(self, region: str | None = None) -> JobServiceClient:
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_batch_prediction_job_id(obj: dict) -> str:
"""Return unique id of the batch_prediction_job."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.deprecated import deprecated
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -59,7 +60,7 @@
from google.cloud.aiplatform_v1.types import CustomJob, PipelineJob, TrainingPipeline


class CustomJobHook(GoogleBaseHook):
class CustomJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Custom Job APIs."""

def __init__(
Expand Down Expand Up @@ -277,14 +278,6 @@ def extract_custom_job_id_from_training_pipeline(training_pipeline: dict[str, An
"""Return a unique Custom Job id from a serialized TrainingPipeline proto."""
return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_job(self) -> None:
"""Cancel Job for training pipeline."""
if self._job:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1 import DatasetServiceClient

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -42,7 +42,7 @@
from google.protobuf.field_mask_pb2 import FieldMask


class DatasetHook(GoogleBaseHook):
class DatasetHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Dataset APIs."""

def get_dataset_service_client(self, region: str | None = None) -> DatasetServiceClient:
Expand All @@ -56,14 +56,6 @@ def get_dataset_service_client(self, region: str | None = None) -> DatasetServic
credentials=self.get_credentials(), client_info=CLIENT_INFO, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_dataset_id(obj: dict) -> str:
"""Return unique id of the dataset."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform_v1 import EndpointServiceClient

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
Expand All @@ -37,7 +37,7 @@
from google.protobuf.field_mask_pb2 import FieldMask


class EndpointServiceHook(GoogleBaseHook):
class EndpointServiceHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Endpoint Service APIs."""

def get_endpoint_service_client(self, region: str | None = None) -> EndpointServiceClient:
Expand All @@ -51,14 +51,6 @@ def get_endpoint_service_client(self, region: str | None = None) -> EndpointServ
credentials=self.get_credentials(), client_info=self.client_info, client_options=client_options
)

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

@staticmethod
def extract_endpoint_id(obj: dict) -> str:
"""Return unique id of the endpoint."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.operation_helpers import OperationHelper

if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import AsyncRetry, Retry
from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager


class HyperparameterTuningJobHook(GoogleBaseHook):
class HyperparameterTuningJobHook(GoogleBaseHook, OperationHelper):
"""Hook for Google Cloud Vertex AI Hyperparameter Tuning Job APIs."""

def __init__(
Expand Down Expand Up @@ -134,14 +135,6 @@ def extract_hyperparameter_tuning_job_id(obj: dict) -> str:
"""Return unique id of the hyperparameter_tuning_job."""
return obj["name"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except Exception:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def cancel_hyperparameter_tuning_job(self) -> None:
"""Cancel HyperparameterTuningJob."""
if self._hyperparameter_tuning_job:
Expand Down
Loading
Loading