diff --git a/docs/apache-airflow-providers-google/operators/cloud/alloy_db.rst b/docs/apache-airflow-providers-google/operators/cloud/alloy_db.rst new file mode 100644 index 0000000000000..7385bb8d0be81 --- /dev/null +++ b/docs/apache-airflow-providers-google/operators/cloud/alloy_db.rst @@ -0,0 +1,73 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Google Cloud AlloyDB Operators +=============================== + +The `AlloyDB for PostgreSQL `__ +is a fully managed, PostgreSQL-compatible database service that's designed for your most demanding workloads, +including hybrid transactional and analytical processing. AlloyDB pairs a Google-built database engine with a +cloud-based, multi-node architecture to deliver enterprise-grade performance, reliability, and availability. + +Airflow provides operators to manage AlloyDB clusters. + +Prerequisite Tasks +^^^^^^^^^^^^^^^^^^ + +.. include:: /operators/_partials/prerequisite_tasks.rst + +.. _howto/operator:AlloyDBCreateClusterOperator: + +Create cluster +"""""""""""""" + +To create an AlloyDB cluster (primary end secondary) you can use +:class:`~airflow.providers.google.cloud.operators.alloy_db.AlloyDBCreateClusterOperator`. + +.. exampleinclude:: /../../providers/tests/system/google/cloud/alloy_db/example_alloy_db.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_alloy_db_create_cluster] + :end-before: [END howto_operator_alloy_db_create_cluster] + +.. _howto/operator:AlloyDBUpdateClusterOperator: + +Update cluster +"""""""""""""" + +To update an AlloyDB cluster you can use +:class:`~airflow.providers.google.cloud.operators.alloy_db.AlloyDBUpdateClusterOperator`. + +.. exampleinclude:: /../../providers/tests/system/google/cloud/alloy_db/example_alloy_db.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_alloy_db_update_cluster] + :end-before: [END howto_operator_alloy_db_update_cluster] + +.. _howto/operator:AlloyDBDeleteClusterOperator: + +Delete cluster +"""""""""""""" + +To delete an AlloyDB cluster you can use +:class:`~airflow.providers.google.cloud.operators.alloy_db.AlloyDBDeleteClusterOperator`. + +.. exampleinclude:: /../../providers/tests/system/google/cloud/alloy_db/example_alloy_db.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_alloy_db_delete_cluster] + :end-before: [END howto_operator_alloy_db_delete_cluster] diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index fe11c5804eb42..d9867d717d01d 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -647,6 +647,7 @@ "google-auth-httplib2>=0.0.1", "google-auth>=2.29.0", "google-cloud-aiplatform>=1.70.0", + "google-cloud-alloydb", "google-cloud-automl>=2.12.0", "google-cloud-batch>=0.13.0", "google-cloud-bigquery-datatransfer>=3.13.0", diff --git a/providers/src/airflow/providers/google/cloud/hooks/alloy_db.py b/providers/src/airflow/providers/google/cloud/hooks/alloy_db.py new file mode 100644 index 0000000000000..3499ad4841091 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/hooks/alloy_db.py @@ -0,0 +1,301 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Module contains a Google Alloy DB Hook. + +.. spelling:word-list:: + + etag +""" + +from __future__ import annotations + +from collections.abc import Sequence +from copy import deepcopy +from typing import TYPE_CHECKING + +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud import alloydb_v1 + +from airflow.exceptions import AirflowException +from airflow.providers.google.common.consts import CLIENT_INFO +from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook + +if TYPE_CHECKING: + import proto + from google.api_core.operation import Operation + from google.api_core.retry import Retry + from google.protobuf.field_mask_pb2 import FieldMask + + +class AlloyDbHook(GoogleBaseHook): + """Google Alloy DB Hook.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._client: alloydb_v1.AlloyDBAdminClient | None = None + + def get_conn(self) -> alloydb_v1.AlloyDBAdminClient: + """Retrieve AlloyDB client.""" + if not self._client: + self._client = alloydb_v1.AlloyDBAdminClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO + ) + return self._client + + def wait_for_operation(self, timeout: float | None, operation: Operation) -> proto.Message: + """Wait for long-lasting operation to complete.""" + self.log.info("Waiting for operation to complete...") + _timeout: int | None = int(timeout) if timeout else None + try: + return operation.result(timeout=_timeout) + except Exception: + error = operation.exception(timeout=_timeout) + raise AirflowException(error) + + @staticmethod + def cluster_name(project_id, location, cluster_id): + return f"projects/{project_id}/locations/{location}/clusters/{cluster_id}" + + @GoogleBaseHook.fallback_to_default_project_id + def create_cluster( + self, + cluster_id: str, + cluster: alloydb_v1.Cluster | dict, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + validate_only: bool = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Create an Alloy DB cluster. + + .. seealso:: + For more details see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.CreateClusterRequest + + :param cluster_id: Required. ID of the cluster to create. + :param cluster: Required. Cluster to create. For more details please see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.Cluster + :param location: Required. The ID of the Google Cloud region where the cluster is located. + :param project_id: Optional. The ID of the Google Cloud project where the cluster is located. + :param request_id: Optional. The ID of an existing request object. + :param validate_only: Optional. If set, performs request validation, but does not actually execute + the create request. + :param retry: Optional. Designation of what errors, if any, should be retried. + :param timeout: Optional. The timeout for this request. + :param metadata: Optional. Strings which should be sent along with the request as metadata. + """ + client = self.get_conn() + return client.create_cluster( + request={ + "parent": f"projects/{project_id}/locations/{location}", + "cluster_id": cluster_id, + "cluster": cluster, + "request_id": request_id, + "validate_only": validate_only, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def create_secondary_cluster( + self, + cluster_id: str, + cluster: alloydb_v1.Cluster | dict, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + validate_only: bool = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Create a secondary Alloy DB cluster. + + .. seealso:: + For more details see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.CreateClusterRequest + + :param cluster_id: Required. ID of the cluster to create. + :param cluster: Required. Cluster to create. For more details please see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.Cluster + :param location: Required. The ID of the Google Cloud region where the cluster is located. + :param project_id: Optional. The ID of the Google Cloud project where the cluster is located. + :param request_id: Optional. The ID of an existing request object. + :param validate_only: Optional. If set, performs request validation, but does not actually execute + the create request. + :param retry: Optional. Designation of what errors, if any, should be retried. + :param timeout: Optional. The timeout for this request. + :param metadata: Optional. Strings which should be sent along with the request as metadata. + """ + client = self.get_conn() + return client.create_secondary_cluster( + request={ + "parent": f"projects/{project_id}/locations/{location}", + "cluster_id": cluster_id, + "cluster": cluster, + "request_id": request_id, + "validate_only": validate_only, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def get_cluster( + self, + cluster_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> alloydb_v1.Cluster: + """ + Retrieve an Alloy DB cluster. + + .. seealso:: + For more details see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.GetClusterRequest + + :param cluster_id: Required. ID of the cluster to create. + :param location: Required. The ID of the Google Cloud region where the cluster is located. + :param project_id: Optional. The ID of the Google Cloud project where the cluster is located. + :param retry: Optional. Designation of what errors, if any, should be retried. + :param timeout: Optional. The timeout for this request. + :param metadata: Optional. Strings which should be sent along with the request as metadata. + """ + client = self.get_conn() + return client.get_cluster( + request={ + "name": self.cluster_name(project_id, location, cluster_id), + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def update_cluster( + self, + cluster_id: str, + cluster: alloydb_v1.Cluster | dict, + location: str, + update_mask: FieldMask | dict | None = None, + project_id: str = PROVIDE_PROJECT_ID, + allow_missing: bool = False, + request_id: str | None = None, + validate_only: bool = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Update an Alloy DB cluster. + + .. seealso:: + For more details see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.UpdateClusterRequest + + :param cluster_id: Required. ID of the cluster to update. + :param cluster: Required. Cluster to create. For more details please see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.Cluster + :param location: Required. The ID of the Google Cloud region where the cluster is located. + :param update_mask: Optional. Field mask is used to specify the fields to be overwritten in the + Cluster resource by the update. + :param request_id: Optional. The ID of an existing request object. + :param validate_only: Optional. If set, performs request validation, but does not actually execute + the create request. + :param project_id: Optional. The ID of the Google Cloud project where the cluster is located. + :param allow_missing: Optional. If set to true, update succeeds even if cluster is not found. + In that case, a new cluster is created and update_mask is ignored. + :param retry: Optional. Designation of what errors, if any, should be retried. + :param timeout: Optional. The timeout for this request. + :param metadata: Optional. Strings which should be sent along with the request as metadata. + """ + client = self.get_conn() + _cluster = deepcopy(cluster) if isinstance(cluster, dict) else alloydb_v1.Cluster.to_dict(cluster) + _cluster["name"] = self.cluster_name(project_id, location, cluster_id) + return client.update_cluster( + request={ + "update_mask": update_mask, + "cluster": _cluster, + "request_id": request_id, + "validate_only": validate_only, + "allow_missing": allow_missing, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + @GoogleBaseHook.fallback_to_default_project_id + def delete_cluster( + self, + cluster_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + etag: str | None = None, + validate_only: bool = False, + force: bool = False, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete an Alloy DB cluster. + + .. seealso:: + For more details see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.DeleteClusterRequest + + :param cluster_id: Required. ID of the cluster to delete. + :param location: Required. The ID of the Google Cloud region where the cluster is located. + :param project_id: Optional. The ID of the Google Cloud project where the cluster is located. + :param request_id: Optional. The ID of an existing request object. + :param etag: Optional. The current etag of the Cluster. If an etag is provided and does not match the + current etag of the Cluster, deletion will be blocked and an ABORTED error will be returned. + :param validate_only: Optional. If set, performs request validation, but does not actually execute + the create request. + :param force: Optional. Whether to cascade delete child instances for given cluster. + :param retry: Optional. Designation of what errors, if any, should be retried. + :param timeout: Optional. The timeout for this request. + :param metadata: Optional. Strings which should be sent along with the request as metadata. + """ + client = self.get_conn() + return client.delete_cluster( + request={ + "name": self.cluster_name(project_id, location, cluster_id), + "request_id": request_id, + "etag": etag, + "validate_only": validate_only, + "force": force, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) diff --git a/providers/src/airflow/providers/google/cloud/links/alloy_db.py b/providers/src/airflow/providers/google/cloud/links/alloy_db.py new file mode 100644 index 0000000000000..6b4c394a67cdd --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/links/alloy_db.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Cloud AlloyDB links.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.models import BaseOperator + from airflow.utils.context import Context + +ALLOY_DB_BASE_LINK = "/alloydb" +ALLOY_DB_CLUSTER_LINK = ( + ALLOY_DB_BASE_LINK + "/locations/{location_id}/clusters/{cluster_id}?project={project_id}" +) + + +class AlloyDBClusterLink(BaseGoogleLink): + """Helper class for constructing AlloyDB cluster Link.""" + + name = "AlloyDB Cluster" + key = "alloy_db_cluster" + format_str = ALLOY_DB_CLUSTER_LINK + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + location_id: str, + cluster_id: str, + project_id: str | None, + ): + task_instance.xcom_push( + context, + key=AlloyDBClusterLink.key, + value={"location_id": location_id, "cluster_id": cluster_id, "project_id": project_id}, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/alloy_db.py b/providers/src/airflow/providers/google/cloud/operators/alloy_db.py new file mode 100644 index 0000000000000..a8dd12a006e35 --- /dev/null +++ b/providers/src/airflow/providers/google/cloud/operators/alloy_db.py @@ -0,0 +1,465 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module contains Google Cloud Alloy DB operators. + +.. spelling:word-list:: + + etag +""" + +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import AlreadyExists, InvalidArgument +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.cloud import alloydb_v1 + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.alloy_db import AlloyDbHook +from airflow.providers.google.cloud.links.alloy_db import AlloyDBClusterLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + import proto + from google.api_core.operation import Operation + from google.api_core.retry import Retry + from google.protobuf.field_mask_pb2 import FieldMask + + from airflow.utils.context import Context + + +class AlloyDBBaseOperator(GoogleCloudBaseOperator): + """ + Base class for all AlloyDB operators. + + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "project_id", + "location", + "gcp_conn_id", + ) + + def __init__( + self, + project_id: str, + location: str, + gcp_conn_id: str = "google_cloud_default", + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.project_id = project_id + self.location = location + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + @cached_property + def hook(self) -> AlloyDbHook: + return AlloyDbHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class AlloyDBWriteBaseOperator(AlloyDBBaseOperator): + """ + Base class for writing AlloyDB operators. + + These operators perform create, update or delete operations. with the objects (not inside of database). + + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + so that if you must retry your request, the server ignores the request if it has already been + completed. The server guarantees that for at least 60 minutes since the first request. + For example, consider a situation where you make an initial request and the request times out. + If you make the request again with the same request ID, the server can check if the original operation + with the same request ID was received, and if so, ignores the second request. + This prevents clients from accidentally creating duplicate commitments. + The request ID must be a valid UUID with the exception that zero UUID is not supported + (00000000-0000-0000-0000-000000000000). + :param validate_request: Optional. If set, performs request validation, but does not actually + execute the request. + """ + + template_fields: Sequence[str] = tuple( + {"request_id", "validate_request"} | set(AlloyDBBaseOperator.template_fields) + ) + + def __init__( + self, + request_id: str | None = None, + validate_request: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.request_id = request_id + self.validate_request = validate_request + + def get_operation_result(self, operation: Operation) -> proto.Message | None: + """ + Retrieve operation result as a proto.Message. + + If the `validate_request` parameter is set, then no operation is performed and thus nothing to wait. + """ + if self.validate_request: + self.log.info("The request validation has been passed successfully!") + else: + return self.hook.wait_for_operation(timeout=self.timeout, operation=operation) + return None + + +class AlloyDBCreateClusterOperator(AlloyDBWriteBaseOperator): + """ + Create an Alloy DB cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AlloyDBCreateClusterOperator` + + :param cluster_id: Required. ID of the cluster to create. + :param cluster_configuration: Required. Cluster to create. For more details please see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.Cluster + :param is_secondary: Required. Specifies if the Cluster to be created is Primary or Secondary. + Please note, if set True, then specify the `secondary_config` field in the cluster so the created + secondary cluster was pointing to the primary cluster. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + so that if you must retry your request, the server ignores the request if it has already been + completed. The server guarantees that for at least 60 minutes since the first request. + For example, consider a situation where you make an initial request and the request times out. + If you make the request again with the same request ID, the server can check if the original operation + with the same request ID was received, and if so, ignores the second request. + This prevents clients from accidentally creating duplicate commitments. + The request ID must be a valid UUID with the exception that zero UUID is not supported + (00000000-0000-0000-0000-000000000000). + :param validate_request: Optional. If set, performs request validation, but does not actually + execute the request. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "is_secondary"} | set(AlloyDBWriteBaseOperator.template_fields) + ) + operator_extra_links = (AlloyDBClusterLink(),) + + def __init__( + self, + cluster_id: str, + cluster_configuration: alloydb_v1.Cluster | dict, + is_secondary: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.cluster_configuration = cluster_configuration + self.is_secondary = is_secondary + + def execute(self, context: Context) -> Any: + message = ( + "Validating a Create AlloyDB cluster request." + if self.validate_request + else "Creating an AlloyDB cluster." + ) + self.log.info(message) + + try: + create_method = ( + self.hook.create_secondary_cluster if self.is_secondary else self.hook.create_cluster + ) + operation = create_method( + cluster_id=self.cluster_id, + cluster=self.cluster_configuration, + location=self.location, + project_id=self.project_id, + request_id=self.request_id, + validate_only=self.validate_request, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except AlreadyExists: + self.log.info("AlloyDB cluster %s already exists.", self.cluster_id) + result = self.hook.get_cluster( + cluster_id=self.cluster_id, + location=self.location, + project_id=self.project_id, + ) + result = alloydb_v1.Cluster.to_dict(result) + except InvalidArgument as ex: + if "cannot create more than one secondary cluster per primary cluster" in ex.message: + result = self.hook.get_cluster( + cluster_id=self.cluster_id, + location=self.location, + project_id=self.project_id, + ) + result = alloydb_v1.Cluster.to_dict(result) + self.log.info("AlloyDB cluster %s already exists.", result.get("name").split("/")[-1]) + else: + raise AirflowException(ex.message) + except Exception as ex: + raise AirflowException(ex) + else: + operation_result = self.get_operation_result(operation) + result = alloydb_v1.Cluster.to_dict(operation_result) if operation_result else None + + if result: + AlloyDBClusterLink.persist( + context=context, + task_instance=self, + location_id=self.location, + cluster_id=self.cluster_id, + project_id=self.project_id, + ) + + return result + + +class AlloyDBUpdateClusterOperator(AlloyDBWriteBaseOperator): + """ + Update an Alloy DB cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AlloyDBUpdateClusterOperator` + + :param cluster_id: Required. ID of the cluster to create. + :param cluster_configuration: Required. Cluster to update. For more details please see API documentation: + https://cloud.google.com/python/docs/reference/alloydb/latest/google.cloud.alloydb_v1.types.Cluster + :param update_mask: Optional. Field mask is used to specify the fields to be overwritten in the + Cluster resource by the update. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + so that if you must retry your request, the server ignores the request if it has already been + completed. The server guarantees that for at least 60 minutes since the first request. + For example, consider a situation where you make an initial request and the request times out. + If you make the request again with the same request ID, the server can check if the original operation + with the same request ID was received, and if so, ignores the second request. + This prevents clients from accidentally creating duplicate commitments. + The request ID must be a valid UUID with the exception that zero UUID is not supported + (00000000-0000-0000-0000-000000000000). + :param validate_request: Optional. If set, performs request validation, but does not actually + execute the request. + :param allow_missing: Optional. If set to true, update succeeds even if cluster is not found. + In that case, a new cluster is created and update_mask is ignored. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "allow_missing"} | set(AlloyDBWriteBaseOperator.template_fields) + ) + operator_extra_links = (AlloyDBClusterLink(),) + + def __init__( + self, + cluster_id: str, + cluster_configuration: alloydb_v1.Cluster | dict, + update_mask: FieldMask | dict | None = None, + allow_missing: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.cluster_configuration = cluster_configuration + self.update_mask = update_mask + self.allow_missing = allow_missing + + def execute(self, context: Context) -> Any: + message = ( + "Validating an Update AlloyDB cluster request." + if self.validate_request + else "Updating an AlloyDB cluster." + ) + self.log.info(message) + + try: + operation = self.hook.update_cluster( + cluster_id=self.cluster_id, + project_id=self.project_id, + location=self.location, + cluster=self.cluster_configuration, + update_mask=self.update_mask, + allow_missing=self.allow_missing, + request_id=self.request_id, + validate_only=self.validate_request, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except Exception as ex: + raise AirflowException(ex) from ex + else: + operation_result = self.get_operation_result(operation) + result = alloydb_v1.Cluster.to_dict(operation_result) if operation_result else None + + AlloyDBClusterLink.persist( + context=context, + task_instance=self, + location_id=self.location, + cluster_id=self.cluster_id, + project_id=self.project_id, + ) + + if not self.validate_request: + self.log.info("AlloyDB cluster %s was successfully updated.", self.cluster_id) + return result + + +class AlloyDBDeleteClusterOperator(AlloyDBWriteBaseOperator): + """ + Delete an Alloy DB cluster. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:AlloyDBDeleteClusterOperator` + + :param cluster_id: Required. ID of the cluster to create. + :param request_id: Optional. An optional request ID to identify requests. Specify a unique request ID + so that if you must retry your request, the server ignores the request if it has already been + completed. The server guarantees that for at least 60 minutes since the first request. + For example, consider a situation where you make an initial request and the request times out. + If you make the request again with the same request ID, the server can check if the original operation + with the same request ID was received, and if so, ignores the second request. + This prevents clients from accidentally creating duplicate commitments. + The request ID must be a valid UUID with the exception that zero UUID is not supported + (00000000-0000-0000-0000-000000000000). + :param validate_request: Optional. If set, performs request validation, but does not actually + execute the request. + :param etag: Optional. The current etag of the Cluster. If an etag is provided and does not match the + current etag of the Cluster, deletion will be blocked and an ABORTED error will be returned. + :param force: Optional. Whether to cascade delete child instances for given cluster. + :param project_id: Required. The ID of the Google Cloud project where the service is used. + :param location: Required. The ID of the Google Cloud region where the service is used. + :param gcp_conn_id: Optional. The connection ID to use to connect to Google Cloud. + :param retry: Optional. A retry object used to retry requests. If `None` is specified, requests will not + be retried. + :param timeout: Optional. The amount of time, in seconds, to wait for the request to complete. + Note that if `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Optional. Additional metadata that is provided to the method. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_id", "etag", "force"} | set(AlloyDBWriteBaseOperator.template_fields) + ) + + def __init__( + self, + cluster_id: str, + etag: str | None = None, + force: bool = False, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.cluster_id = cluster_id + self.etag = etag + self.force = force + + def execute(self, context: Context) -> Any: + message = ( + "Validating a Delete AlloyDB cluster request." + if self.validate_request + else "Deleting an AlloyDB cluster." + ) + self.log.info(message) + + try: + operation = self.hook.delete_cluster( + cluster_id=self.cluster_id, + project_id=self.project_id, + location=self.location, + etag=self.etag, + force=self.force, + request_id=self.request_id, + validate_only=self.validate_request, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except Exception as ex: + raise AirflowException(ex) from ex + else: + _ = self.get_operation_result(operation) + + if not self.validate_request: + self.log.info("AlloyDB cluster %s was successfully removed.", self.cluster_id) diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 61fd6f9b98a74..43157a220f5d1 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -119,6 +119,7 @@ dependencies: - google-auth>=2.29.0 - google-auth-httplib2>=0.0.1 - google-cloud-aiplatform>=1.70.0 + - google-cloud-alloydb - google-cloud-automl>=2.12.0 # Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0 - google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.* @@ -240,6 +241,11 @@ integrations: external-doc-url: https://cloud.google.com/bigtable/ logo: /integration-logos/gcp/Cloud-Bigtable.png tags: [gcp] + - integration-name: Google Cloud AlloyDB + external-doc-url: https://cloud.google.com/alloydb + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/alloy_db.rst + tags: [gcp] - integration-name: Google Cloud Build external-doc-url: https://cloud.google.com/build/ how-to-guide: @@ -548,6 +554,9 @@ operators: - integration-name: Google Cloud Common python-modules: - airflow.providers.google.cloud.operators.cloud_base + - integration-name: Google Cloud AlloyDB + python-modules: + - airflow.providers.google.cloud.operators.alloy_db - integration-name: Google AutoML python-modules: - airflow.providers.google.cloud.operators.automl @@ -806,6 +815,9 @@ hooks: - integration-name: Google Bigtable python-modules: - airflow.providers.google.cloud.hooks.bigtable + - integration-name: Google Cloud AlloyDB + python-modules: + - airflow.providers.google.cloud.hooks.alloy_db - integration-name: Google Cloud Build python-modules: - airflow.providers.google.cloud.hooks.cloud_build @@ -1180,6 +1192,7 @@ connection-types: connection-type: leveldb extra-links: + - airflow.providers.google.cloud.links.alloy_db.AlloyDBClusterLink - airflow.providers.google.cloud.links.dataform.DataformRepositoryLink - airflow.providers.google.cloud.links.dataform.DataformWorkspaceLink - airflow.providers.google.cloud.links.dataform.DataformWorkflowInvocationLink diff --git a/providers/tests/google/cloud/hooks/test_alloy_db.py b/providers/tests/google/cloud/hooks/test_alloy_db.py new file mode 100644 index 0000000000000..6a856107d57f9 --- /dev/null +++ b/providers/tests/google/cloud/hooks/test_alloy_db.py @@ -0,0 +1,295 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from copy import deepcopy +from unittest import mock + +import pytest +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud import alloydb_v1 + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.alloy_db import AlloyDbHook +from airflow.providers.google.common.consts import CLIENT_INFO + +TEST_GCP_PROJECT = "test-project" +TEST_GCP_REGION = "global" +TEST_GCP_CONN_ID = "test_conn_id" +TEST_IMPERSONATION_CHAIN = "test_impersonation_chain" + +TEST_CLUSTER_ID = "test_cluster_id" +TEST_CLUSTER = {} +TEST_CLUSTER_NAME = f"projects/{TEST_GCP_PROJECT}/locations/{TEST_GCP_REGION}/clusters/{TEST_CLUSTER_ID}" +TEST_UPDATE_MASK = None +TEST_ALLOW_MISSING = False +TEST_ETAG = "test-etag" +TEST_FORCE = False +TEST_REQUEST_ID = "test_request_id" +TEST_VALIDATE_ONLY = False + +TEST_RETRY = DEFAULT +TEST_TIMEOUT = None +TEST_METADATA = () + +HOOK_PATH = "airflow.providers.google.cloud.hooks.alloy_db.{}" + + +class TestAlloyDbHook: + def setup_method(self): + with mock.patch("airflow.hooks.base.BaseHook.get_connection"): + self.hook = AlloyDbHook( + gcp_conn_id=TEST_GCP_CONN_ID, + ) + + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_credentials")) + @mock.patch(HOOK_PATH.format("alloydb_v1.AlloyDBAdminClient")) + def test_get_conn(self, mock_client, mock_get_credentials): + mock_credentials = mock_get_credentials.return_value + expected_client = mock_client.return_value + + client = self.hook.get_conn() + + assert client == expected_client + mock_get_credentials.assert_called_once() + mock_client.assert_called_once_with( + credentials=mock_credentials, + client_info=CLIENT_INFO, + ) + + @pytest.mark.parametrize( + "given_timeout, expected_timeout", + [ + (None, None), + (0.0, None), + (10.0, 10), + (10.9, 10), + ], + ) + @mock.patch(HOOK_PATH.format("AlloyDbHook.log")) + def test_wait_for_operation(self, mock_log, given_timeout, expected_timeout): + mock_operation = mock.MagicMock() + expected_operation_result = mock_operation.result.return_value + + result = self.hook.wait_for_operation(timeout=given_timeout, operation=mock_operation) + + assert result == expected_operation_result + mock_log.info.assert_called_once_with("Waiting for operation to complete...") + mock_operation.result.assert_called_once_with(timeout=expected_timeout) + + @pytest.mark.parametrize( + "given_timeout, expected_timeout", + [ + (None, None), + (0.0, None), + (10.0, 10), + (10.9, 10), + ], + ) + @mock.patch(HOOK_PATH.format("AlloyDbHook.log")) + def test_wait_for_operation_exception(self, mock_log, given_timeout, expected_timeout): + mock_operation = mock.MagicMock() + mock_operation.result.side_effect = Exception + + with pytest.raises(AirflowException): + self.hook.wait_for_operation(timeout=given_timeout, operation=mock_operation) + + mock_log.info.assert_called_once_with("Waiting for operation to complete...") + mock_operation.result.assert_called_once_with(timeout=expected_timeout) + mock_operation.exception.assert_called_once_with(timeout=expected_timeout) + + def test_cluster_name(self): + cluster_name = self.hook.cluster_name(TEST_GCP_PROJECT, TEST_GCP_REGION, TEST_CLUSTER_ID) + assert cluster_name == TEST_CLUSTER_NAME + + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_conn")) + def test_create_cluster(self, mock_client): + mock_create_cluster = mock_client.return_value.create_cluster + expected_result = mock_create_cluster.return_value + expected_request = { + "parent": f"projects/{TEST_GCP_PROJECT}/locations/{TEST_GCP_REGION}", + "cluster_id": TEST_CLUSTER_ID, + "cluster": TEST_CLUSTER, + "request_id": TEST_REQUEST_ID, + "validate_only": TEST_VALIDATE_ONLY, + } + + result = self.hook.create_cluster( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + assert result == expected_result + mock_client.assert_called_once() + mock_create_cluster.assert_called_once_with( + request=expected_request, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_conn")) + def test_create_secondary_cluster(self, mock_client): + mock_create_secondary_cluster = mock_client.return_value.create_secondary_cluster + expected_result = mock_create_secondary_cluster.return_value + expected_request = { + "parent": f"projects/{TEST_GCP_PROJECT}/locations/{TEST_GCP_REGION}", + "cluster_id": TEST_CLUSTER_ID, + "cluster": TEST_CLUSTER, + "request_id": TEST_REQUEST_ID, + "validate_only": TEST_VALIDATE_ONLY, + } + + result = self.hook.create_secondary_cluster( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + assert result == expected_result + mock_client.assert_called_once() + mock_create_secondary_cluster.assert_called_once_with( + request=expected_request, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_conn")) + def test_get_cluster(self, mock_client): + mock_get_cluster = mock_client.return_value.get_cluster + expected_result = mock_get_cluster.return_value + + result = self.hook.get_cluster( + cluster_id=TEST_CLUSTER_ID, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + assert result == expected_result + mock_client.assert_called_once() + mock_get_cluster.assert_called_once_with( + request={"name": TEST_CLUSTER_NAME}, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @pytest.mark.parametrize( + "given_cluster, expected_cluster", + [ + (TEST_CLUSTER, {**deepcopy(TEST_CLUSTER), **{"name": TEST_CLUSTER_NAME}}), + (alloydb_v1.Cluster(), {"name": TEST_CLUSTER_NAME}), + ({}, {"name": TEST_CLUSTER_NAME}), + ], + ) + @mock.patch(HOOK_PATH.format("deepcopy")) + @mock.patch(HOOK_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_conn")) + def test_update_cluster(self, mock_client, mock_to_dict, mock_deepcopy, given_cluster, expected_cluster): + mock_update_cluster = mock_client.return_value.update_cluster + expected_result = mock_update_cluster.return_value + mock_deepcopy.return_value = expected_cluster + mock_to_dict.return_value = expected_cluster + + expected_request = { + "update_mask": TEST_UPDATE_MASK, + "cluster": expected_cluster, + "request_id": TEST_REQUEST_ID, + "validate_only": TEST_VALIDATE_ONLY, + "allow_missing": TEST_ALLOW_MISSING, + } + + result = self.hook.update_cluster( + cluster_id=TEST_CLUSTER_ID, + cluster=given_cluster, + location=TEST_GCP_REGION, + update_mask=TEST_UPDATE_MASK, + project_id=TEST_GCP_PROJECT, + allow_missing=TEST_ALLOW_MISSING, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + assert result == expected_result + if isinstance(given_cluster, dict): + mock_deepcopy.assert_called_once_with(given_cluster) + assert not mock_to_dict.called + else: + assert not mock_deepcopy.called + mock_to_dict.assert_called_once_with(given_cluster) + mock_client.assert_called_once() + mock_update_cluster.assert_called_once_with( + request=expected_request, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + @mock.patch(HOOK_PATH.format("AlloyDbHook.get_conn")) + def test_delete_cluster(self, mock_client): + mock_delete_cluster = mock_client.return_value.delete_cluster + expected_result = mock_delete_cluster.return_value + expected_request = { + "name": TEST_CLUSTER_NAME, + "request_id": TEST_REQUEST_ID, + "etag": TEST_ETAG, + "validate_only": TEST_VALIDATE_ONLY, + "force": TEST_FORCE, + } + + result = self.hook.delete_cluster( + cluster_id=TEST_CLUSTER_ID, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + etag=TEST_ETAG, + validate_only=TEST_VALIDATE_ONLY, + force=TEST_FORCE, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + + assert result == expected_result + mock_client.assert_called_once() + mock_delete_cluster.assert_called_once_with( + request=expected_request, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) diff --git a/providers/tests/google/cloud/links/test_alloy_db.py b/providers/tests/google/cloud/links/test_alloy_db.py new file mode 100644 index 0000000000000..26eebc99a4c88 --- /dev/null +++ b/providers/tests/google/cloud/links/test_alloy_db.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.links.alloy_db import AlloyDBClusterLink + +TEST_LOCATION = "test-location" +TEST_CLUSTER_ID = "test-cluster-id" +TEST_PROJECT_ID = "test-project-id" +EXPECTED_ALLOY_DB_CLUSTER_LINK_NAME = "AlloyDB Cluster" +EXPECTED_ALLOY_DB_CLUSTER_LINK_KEY = "alloy_db_cluster" +EXPECTED_ALLOY_DB_CLUSTER_LINK_FORMAT_STR = ( + "/alloydb/locations/{location_id}/clusters/{cluster_id}?project={project_id}" +) + + +class TestAlloyDBClusterLink: + def test_class_attributes(self): + assert AlloyDBClusterLink.key == EXPECTED_ALLOY_DB_CLUSTER_LINK_KEY + assert AlloyDBClusterLink.name == EXPECTED_ALLOY_DB_CLUSTER_LINK_NAME + assert AlloyDBClusterLink.format_str == EXPECTED_ALLOY_DB_CLUSTER_LINK_FORMAT_STR + + def test_persist(self): + mock_context, mock_task_instance = mock.MagicMock(), mock.MagicMock() + + AlloyDBClusterLink.persist( + context=mock_context, + task_instance=mock_task_instance, + location_id=TEST_LOCATION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_PROJECT_ID, + ) + + mock_task_instance.xcom_push.assert_called_once_with( + mock_context, + key=EXPECTED_ALLOY_DB_CLUSTER_LINK_KEY, + value={ + "location_id": TEST_LOCATION, + "cluster_id": TEST_CLUSTER_ID, + "project_id": TEST_PROJECT_ID, + }, + ) diff --git a/providers/tests/google/cloud/operators/test_alloy_db.py b/providers/tests/google/cloud/operators/test_alloy_db.py new file mode 100644 index 0000000000000..8c2012e8e9bc4 --- /dev/null +++ b/providers/tests/google/cloud/operators/test_alloy_db.py @@ -0,0 +1,789 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import call + +import pytest +from google.api_core.exceptions import AlreadyExists, InvalidArgument +from google.api_core.gapic_v1.method import DEFAULT + +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.operators.alloy_db import ( + AlloyDBBaseOperator, + AlloyDBCreateClusterOperator, + AlloyDBDeleteClusterOperator, + AlloyDBUpdateClusterOperator, + AlloyDBWriteBaseOperator, +) + +TEST_TASK_ID = "test-task-id" +TEST_GCP_PROJECT = "test-project" +TEST_GCP_REGION = "global" +TEST_GCP_CONN_ID = "test_conn_id" +TEST_IMPERSONATION_CHAIN = "test_impersonation_chain" +TEST_RETRY = DEFAULT +TEST_TIMEOUT = None +TEST_METADATA = () + +TEST_REQUEST_ID = "test_request_id" +TEST_VALIDATE_ONLY = False + +TEST_CLUSTER_ID = "test_cluster_id" +TEST_CLUSTER_NAME = f"projects/{TEST_GCP_PROJECT}/locations/{TEST_GCP_REGION}/clusters/{TEST_CLUSTER_ID}" +TEST_CLUSTER = {} +TEST_IS_SECONDARY = False +TEST_UPDATE_MASK = None +TEST_ALLOW_MISSING = False +TEST_ETAG = "test-etag" +TEST_FORCE = False + +OPERATOR_MODULE_PATH = "airflow.providers.google.cloud.operators.alloy_db.{}" + + +class TestAlloyDBBaseOperator: + def setup_method(self): + self.operator = AlloyDBBaseOperator( + task_id=TEST_TASK_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_init(self): + assert self.operator.project_id == TEST_GCP_PROJECT + assert self.operator.location == TEST_GCP_REGION + assert self.operator.gcp_conn_id == TEST_GCP_CONN_ID + assert self.operator.impersonation_chain == TEST_IMPERSONATION_CHAIN + assert self.operator.retry == TEST_RETRY + assert self.operator.timeout == TEST_TIMEOUT + assert self.operator.metadata == TEST_METADATA + + def test_template_fields(self): + expected_template_fields = {"project_id", "location", "gcp_conn_id"} + assert set(AlloyDBBaseOperator.template_fields) == expected_template_fields + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook")) + def test_hook(self, mock_hook): + expected_hook = mock_hook.return_value + + hook_1 = self.operator.hook + hook_2 = self.operator.hook + + mock_hook.assert_called_once_with( + gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN + ) + assert hook_1 == expected_hook + assert hook_2 == expected_hook + + +class TestAlloyDBWriteBaseOperator: + def setup_method(self): + self.operator = AlloyDBWriteBaseOperator( + task_id=TEST_TASK_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + request_id=TEST_REQUEST_ID, + validate_request=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_init(self): + assert self.operator.request_id == TEST_REQUEST_ID + assert self.operator.validate_request == TEST_VALIDATE_ONLY + + def test_template_fields(self): + expected_template_fields = {"request_id", "validate_request"} | set( + AlloyDBBaseOperator.template_fields + ) + assert set(AlloyDBWriteBaseOperator.template_fields) == expected_template_fields + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBWriteBaseOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook")) + def test_get_operation_result(self, mock_hook, mock_log): + mock_operation = mock.MagicMock() + mock_wait_for_operation = mock_hook.return_value.wait_for_operation + expected_result = mock_wait_for_operation.return_value + + result = self.operator.get_operation_result(mock_operation) + + assert result == expected_result + assert not mock_log.called + mock_wait_for_operation.assert_called_once_with(timeout=TEST_TIMEOUT, operation=mock_operation) + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBWriteBaseOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook")) + def test_get_operation_result_validate_result(self, mock_hook, mock_log): + mock_operation = mock.MagicMock() + mock_wait_for_operation = mock_hook.return_value.wait_for_operation + self.operator.validate_request = True + + result = self.operator.get_operation_result(mock_operation) + + assert result is None + mock_log.info.assert_called_once_with("The request validation has been passed successfully!") + assert not mock_wait_for_operation.called + + +class TestAlloyDBCreateClusterOperator: + def setup_method(self): + self.operator = AlloyDBCreateClusterOperator( + task_id=TEST_TASK_ID, + cluster_id=TEST_CLUSTER_ID, + cluster_configuration=TEST_CLUSTER, + is_secondary=TEST_IS_SECONDARY, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + request_id=TEST_REQUEST_ID, + validate_request=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_init(self): + assert self.operator.cluster_id == TEST_CLUSTER_ID + assert self.operator.cluster_configuration == TEST_CLUSTER + assert self.operator.is_secondary == TEST_IS_SECONDARY + + def test_template_fields(self): + expected_template_fields = {"cluster_id", "is_secondary"} | set( + AlloyDBWriteBaseOperator.template_fields + ) + assert set(AlloyDBCreateClusterOperator.template_fields) == expected_template_fields + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_operation = mock_create_cluster.return_value + mock_operation_result = mock_get_operation_result.return_value + + expected_message = "Creating an AlloyDB cluster." + expected_result = mock_to_dict.return_value + mock_context = mock.MagicMock() + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with(expected_message) + mock_create_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_secondary_cluster.called + mock_to_dict.assert_called_once_with(mock_operation_result) + mock_get_operation_result.assert_called_once_with(mock_operation) + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result == expected_result + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_is_secondary( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_operation = mock_create_secondary_cluster.return_value + mock_operation_result = mock_get_operation_result.return_value + + expected_message = "Creating an AlloyDB cluster." + expected_result = mock_to_dict.return_value + mock_context = mock.MagicMock() + self.operator.is_secondary = True + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with(expected_message) + assert not mock_create_cluster.called + mock_create_secondary_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_to_dict.assert_called_once_with(mock_operation_result) + mock_get_operation_result.assert_called_once_with(mock_operation) + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result == expected_result + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_validate_request( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_operation = mock_create_cluster.return_value + mock_get_operation_result.return_value = None + + expected_message = "Validating a Create AlloyDB cluster request." + mock_context = mock.MagicMock() + self.operator.validate_request = True + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with(expected_message) + mock_create_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=True, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_secondary_cluster.called + assert not mock_to_dict.called + assert not mock_link.persist.called + mock_get_operation_result.assert_called_once_with(mock_operation) + assert result is None + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_validate_request_is_secondary( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_operation = mock_create_secondary_cluster.return_value + mock_get_operation_result.return_value = None + + expected_message = "Validating a Create AlloyDB cluster request." + mock_context = mock.MagicMock() + self.operator.validate_request = True + self.operator.is_secondary = True + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with(expected_message) + mock_create_secondary_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=True, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_cluster.called + assert not mock_to_dict.called + assert not mock_link.persist.called + mock_get_operation_result.assert_called_once_with(mock_operation) + assert result is None + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_already_exists( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_cluster.side_effect = AlreadyExists("test-message") + + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_get_cluster = mock_hook.return_value.get_cluster + mock_get_cluster_result = mock_get_cluster.return_value + + expected_result = mock_to_dict.return_value + mock_context = mock.MagicMock() + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_has_calls( + [ + call("Creating an AlloyDB cluster."), + call("AlloyDB cluster %s already exists.", TEST_CLUSTER_ID), + ] + ) + mock_create_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_secondary_cluster.called + mock_get_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + ) + mock_to_dict.assert_called_once_with(mock_get_cluster_result) + assert not mock_get_operation_result.called + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result == expected_result + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_invalid_argument( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + expected_error_message = "cannot create more than one secondary cluster per primary cluster" + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_create_secondary_cluster.side_effect = InvalidArgument(message=expected_error_message) + + mock_get_cluster = mock_hook.return_value.get_cluster + mock_get_cluster_result = mock_get_cluster.return_value + + expected_result = mock_to_dict.return_value + expected_result.get.return_value = TEST_CLUSTER_NAME + mock_context = mock.MagicMock() + self.operator.is_secondary = True + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_has_calls( + [ + call("Creating an AlloyDB cluster."), + call("AlloyDB cluster %s already exists.", TEST_CLUSTER_ID), + ] + ) + mock_create_secondary_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_cluster.called + mock_get_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + ) + mock_to_dict.assert_called_once_with(mock_get_cluster_result) + assert not mock_get_operation_result.called + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result == expected_result + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_invalid_argument_exception( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_create_secondary_cluster.side_effect = InvalidArgument(message="Test error") + mock_get_cluster = mock_hook.return_value.get_cluster + expected_result = mock_to_dict.return_value + expected_result.get.return_value = TEST_CLUSTER_NAME + mock_context = mock.MagicMock() + self.operator.is_secondary = True + + with pytest.raises(AirflowException): + self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with("Creating an AlloyDB cluster.") + mock_create_secondary_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_cluster.called + assert not mock_get_cluster.called + assert not mock_to_dict.called + assert not mock_get_operation_result.called + assert not mock_link.persist.called + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBCreateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link): + mock_create_cluster = mock_hook.return_value.create_cluster + mock_create_secondary_cluster = mock_hook.return_value.create_secondary_cluster + mock_create_cluster.side_effect = Exception() + mock_get_cluster = mock_hook.return_value.get_cluster + expected_result = mock_to_dict.return_value + expected_result.get.return_value = TEST_CLUSTER_NAME + mock_context = mock.MagicMock() + + with pytest.raises(AirflowException): + self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with("Creating an AlloyDB cluster.") + mock_create_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + cluster=TEST_CLUSTER, + location=TEST_GCP_REGION, + project_id=TEST_GCP_PROJECT, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_create_secondary_cluster.called + assert not mock_get_cluster.called + assert not mock_to_dict.called + assert not mock_get_operation_result.called + assert not mock_link.persist.called + + +class TestAlloyDBUpdateClusterOperator: + def setup_method(self): + self.operator = AlloyDBUpdateClusterOperator( + task_id=TEST_TASK_ID, + cluster_id=TEST_CLUSTER_ID, + cluster_configuration=TEST_CLUSTER, + update_mask=TEST_UPDATE_MASK, + allow_missing=TEST_ALLOW_MISSING, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + request_id=TEST_REQUEST_ID, + validate_request=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_init(self): + assert self.operator.cluster_id == TEST_CLUSTER_ID + assert self.operator.cluster_configuration == TEST_CLUSTER + assert self.operator.update_mask == TEST_UPDATE_MASK + assert self.operator.allow_missing == TEST_ALLOW_MISSING + + def test_template_fields(self): + expected_template_fields = {"cluster_id", "allow_missing"} | set( + AlloyDBWriteBaseOperator.template_fields + ) + assert set(AlloyDBUpdateClusterOperator.template_fields) == expected_template_fields + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link): + mock_update_cluster = mock_hook.return_value.update_cluster + mock_operation = mock_update_cluster.return_value + mock_operation_result = mock_get_operation_result.return_value + + expected_result = mock_to_dict.return_value + mock_context = mock.MagicMock() + + result = self.operator.execute(context=mock_context) + + mock_update_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + cluster=TEST_CLUSTER, + update_mask=TEST_UPDATE_MASK, + allow_missing=TEST_ALLOW_MISSING, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_get_operation_result.assert_called_once_with(mock_operation) + mock_to_dict.assert_called_once_with(mock_operation_result) + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result == expected_result + mock_log.info.assert_has_calls( + [ + call("Updating an AlloyDB cluster."), + call("AlloyDB cluster %s was successfully updated.", TEST_CLUSTER_ID), + ] + ) + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_validate_request( + self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link + ): + mock_update_cluster = mock_hook.return_value.update_cluster + mock_operation = mock_update_cluster.return_value + mock_get_operation_result.return_value = None + + expected_message = "Validating an Update AlloyDB cluster request." + mock_context = mock.MagicMock() + self.operator.validate_request = True + + result = self.operator.execute(context=mock_context) + + mock_log.info.assert_called_once_with(expected_message) + mock_update_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + cluster=TEST_CLUSTER, + update_mask=TEST_UPDATE_MASK, + allow_missing=TEST_ALLOW_MISSING, + request_id=TEST_REQUEST_ID, + validate_only=True, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_get_operation_result.assert_called_once_with(mock_operation) + assert not mock_to_dict.called + mock_link.persist.assert_called_once_with( + context=mock_context, + task_instance=self.operator, + location_id=TEST_GCP_REGION, + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + ) + assert result is None + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBClusterLink")) + @mock.patch(OPERATOR_MODULE_PATH.format("alloydb_v1.Cluster.to_dict")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBUpdateClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result, mock_to_dict, mock_link): + mock_update_cluster = mock_hook.return_value.update_cluster + mock_update_cluster.side_effect = Exception + + mock_context = mock.MagicMock() + + with pytest.raises(AirflowException): + self.operator.execute(context=mock_context) + + mock_update_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + cluster=TEST_CLUSTER, + update_mask=TEST_UPDATE_MASK, + allow_missing=TEST_ALLOW_MISSING, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_get_operation_result.called + assert not mock_to_dict.called + assert not mock_link.persist.called + mock_log.info.assert_called_once_with("Updating an AlloyDB cluster.") + + +class TestAlloyDBDeleteClusterOperator: + def setup_method(self): + self.operator = AlloyDBDeleteClusterOperator( + task_id=TEST_TASK_ID, + cluster_id=TEST_CLUSTER_ID, + etag=TEST_ETAG, + force=TEST_FORCE, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + gcp_conn_id=TEST_GCP_CONN_ID, + request_id=TEST_REQUEST_ID, + validate_request=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + def test_init(self): + assert self.operator.cluster_id == TEST_CLUSTER_ID + assert self.operator.etag == TEST_ETAG + assert self.operator.force == TEST_FORCE + + def test_template_fields(self): + expected_template_fields = {"cluster_id", "etag", "force"} | set( + AlloyDBWriteBaseOperator.template_fields + ) + assert set(AlloyDBDeleteClusterOperator.template_fields) == expected_template_fields + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute(self, mock_hook, mock_log, mock_get_operation_result): + mock_delete_cluster = mock_hook.return_value.delete_cluster + mock_operation = mock_delete_cluster.return_value + mock_context = mock.MagicMock() + + result = self.operator.execute(context=mock_context) + + mock_delete_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + etag=TEST_ETAG, + force=TEST_FORCE, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_get_operation_result.assert_called_once_with(mock_operation) + assert result is None + mock_log.info.assert_has_calls( + [ + call("Deleting an AlloyDB cluster."), + call("AlloyDB cluster %s was successfully removed.", TEST_CLUSTER_ID), + ] + ) + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_validate_request(self, mock_hook, mock_log, mock_get_operation_result): + mock_delete_cluster = mock_hook.return_value.delete_cluster + mock_operation = mock_delete_cluster.return_value + mock_context = mock.MagicMock() + self.operator.validate_request = True + + result = self.operator.execute(context=mock_context) + + mock_delete_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + etag=TEST_ETAG, + force=TEST_FORCE, + request_id=TEST_REQUEST_ID, + validate_only=True, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + mock_get_operation_result.assert_called_once_with(mock_operation) + assert result is None + mock_log.info.assert_called_once_with("Validating a Delete AlloyDB cluster request.") + + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.get_operation_result")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDBDeleteClusterOperator.log")) + @mock.patch(OPERATOR_MODULE_PATH.format("AlloyDbHook"), new_callable=mock.PropertyMock) + def test_execute_exception(self, mock_hook, mock_log, mock_get_operation_result): + mock_delete_cluster = mock_hook.return_value.delete_cluster + mock_delete_cluster.side_effect = Exception + mock_context = mock.MagicMock() + + with pytest.raises(AirflowException): + _ = self.operator.execute(context=mock_context) + + mock_delete_cluster.assert_called_once_with( + cluster_id=TEST_CLUSTER_ID, + project_id=TEST_GCP_PROJECT, + location=TEST_GCP_REGION, + etag=TEST_ETAG, + force=TEST_FORCE, + request_id=TEST_REQUEST_ID, + validate_only=TEST_VALIDATE_ONLY, + retry=TEST_RETRY, + timeout=TEST_TIMEOUT, + metadata=TEST_METADATA, + ) + assert not mock_get_operation_result.called + mock_log.info.assert_called_once_with("Deleting an AlloyDB cluster.") diff --git a/providers/tests/system/google/cloud/alloy_db/__init__.py b/providers/tests/system/google/cloud/alloy_db/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/system/google/cloud/alloy_db/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/tests/system/google/cloud/alloy_db/example_alloy_db.py b/providers/tests/system/google/cloud/alloy_db/example_alloy_db.py new file mode 100644 index 0000000000000..e8f40562f44b5 --- /dev/null +++ b/providers/tests/system/google/cloud/alloy_db/example_alloy_db.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for Google AlloyDB operators. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.alloy_db import ( + AlloyDBCreateClusterOperator, + AlloyDBDeleteClusterOperator, + AlloyDBUpdateClusterOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "alloy_db" +GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") + +GCP_LOCATION = "europe-north1" +GCP_LOCATION_SECONDARY = "europe-west1" +GCP_NETWORK = "default" +CLUSTER_USER = "postgres-test" +CLUSTER_USER_PASSWORD = "postgres-test-pa$$w0rd" +CLUSTER_ID = f"cluster-{DAG_ID}-{ENV_ID}".replace("_", "-") +SECONDARY_CLUSTER_ID = f"cluster-secondary-{DAG_ID}-{ENV_ID}".replace("_", "-") +CLUSTER = { + "network": f"projects/{GCP_PROJECT_ID}/global/networks/{GCP_NETWORK}", + "initial_user": { + "user": CLUSTER_USER, + "password": CLUSTER_USER_PASSWORD, + }, +} +CLUSTER_UPDATE = { + "automated_backup_policy": { + "enabled": True, + } +} +CLUSTER_UPDATE_MASK = {"paths": ["automated_backup_policy.enabled"]} +SECONDARY_CLUSTER = { + "network": f"projects/{GCP_PROJECT_ID}/global/networks/{GCP_NETWORK}", + "secondary_config": { + "primary_cluster_name": f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/clusters/{CLUSTER_ID}", + }, +} + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2021, 1, 1), + catchup=False, + tags=["example", "alloy-db"], +) as dag: + # [START howto_operator_alloy_db_create_cluster] + create_cluster = AlloyDBCreateClusterOperator( + task_id="create_cluster", + cluster_id=CLUSTER_ID, + cluster_configuration=CLUSTER, + is_secondary=False, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_alloy_db_create_cluster] + + # [START howto_operator_alloy_db_update_cluster] + update_cluster = AlloyDBUpdateClusterOperator( + task_id="update_cluster", + cluster_id=CLUSTER_ID, + cluster_configuration=CLUSTER_UPDATE, + update_mask=CLUSTER_UPDATE_MASK, + location=GCP_LOCATION, + project_id=GCP_PROJECT_ID, + ) + # [END howto_operator_alloy_db_update_cluster] + + create_secondary_cluster = AlloyDBCreateClusterOperator( + task_id="create_secondary_cluster", + cluster_id=SECONDARY_CLUSTER_ID, + cluster_configuration=SECONDARY_CLUSTER, + is_secondary=True, + location=GCP_LOCATION_SECONDARY, + project_id=GCP_PROJECT_ID, + ) + + delete_secondary_cluster = AlloyDBDeleteClusterOperator( + task_id="delete_secondary_cluster", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION_SECONDARY, + cluster_id=SECONDARY_CLUSTER_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + + # [START howto_operator_alloy_db_delete_cluster] + delete_cluster = AlloyDBDeleteClusterOperator( + task_id="delete_cluster", + project_id=GCP_PROJECT_ID, + location=GCP_LOCATION, + cluster_id=CLUSTER_ID, + ) + # [END howto_operator_alloy_db_delete_cluster] + + delete_cluster.trigger_rule = TriggerRule.ALL_DONE + + create_cluster >> update_cluster >> create_secondary_cluster >> delete_secondary_cluster >> delete_cluster + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index ad24f34e0c32b..9a2d0d59e5f40 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -381,6 +381,8 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest } BASE_CLASSES = { + "airflow.providers.google.cloud.operators.alloy_db.AlloyDBBaseOperator", + "airflow.providers.google.cloud.operators.alloy_db.AlloyDBWriteBaseOperator", "airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator", "airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator", "airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator",