From 0997e077fee90d7238834e0d02c3650b6bbb0096 Mon Sep 17 00:00:00 2001
From: Oleg Kachur
Date: Thu, 7 Nov 2024 21:31:19 +0000
Subject: [PATCH] Introduce GCP advanced API (V3) translate native datasets
 operators

- Add support for native datasets for the Cloud Translation API.
- The datasets created via the AutoML API are considered legacy: they remain
  supported, but all new enhancements will be available for native datasets
  (recommended), created by the Cloud Translate API. See more:
  https://cloud.google.com/translate/docs/advanced/automl-upgrade.
---
 .../operators/cloud/translate.rst | 85 +++++
 docs/spelling_wordlist.txt | 1 +
 .../providers/google/cloud/hooks/translate.py | 240 +++++++++++-
 .../providers/google/cloud/links/translate.py | 56 +++
 .../google/cloud/operators/translate.py | 344 +++++++++++++++++-
 .../airflow/providers/google/provider.yaml | 2 +
 .../google/cloud/operators/test_translate.py | 205 ++++++++++-
 .../translate/example_translate_dataset.py | 153 ++++++++
 8 files changed, 1075 insertions(+), 11 deletions(-)
 create mode 100644 providers/tests/system/google/cloud/translate/example_translate_dataset.py

diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
index 6bcc32ec669ce..d56fac26dbeb4 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst
@@ -100,11 +100,96 @@ For parameter definition, take a look at
 :class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator`

+.. _howto/operator:TranslateCreateDatasetOperator:
+
+TranslateCreateDatasetOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Create a native translation dataset using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_translate_automl_create_dataset]
+    :end-before: [END howto_operator_translate_automl_create_dataset]
+
+
+.. _howto/operator:TranslateImportDataOperator:
+
+TranslateImportDataOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Import data into an existing native dataset, using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_translate_automl_import_data]
+    :end-before: [END howto_operator_translate_automl_import_data]
+
+
+.. _howto/operator:TranslateDatasetsListOperator:
+
+TranslateDatasetsListOperator
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Get a list of translation datasets using Cloud Translate API (Advanced V3).
+
+For parameter definition, take a look at
+:class:`~airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator`
+
+Using the operator
+""""""""""""""""""
+
+Basic usage of the operator:
+
+.. 
exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_list_datasets] + :end-before: [END howto_operator_translate_automl_list_datasets] + + +.. _howto/operator:TranslateDeleteDatasetOperator: + +TranslateDeleteDatasetOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Delete a native translation dataset using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_dataset.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_delete_dataset] + :end-before: [END howto_operator_translate_automl_delete_dataset] + + More information """""""""""""""""" See: Base (V2) `Google Cloud Translate documentation `_. Advanced (V3) `Google Cloud Translate (Advanced) documentation `_. +Datasets `Legacy and native dataset comparison `_. Reference diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 20e10c44a12c6..fa8ffb4a2c3c6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -967,6 +967,7 @@ LineItem lineterminator linter linux +ListDatasetsPager ListGenerator ListInfoTypesResponse ListSecretsPager diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py b/providers/src/airflow/providers/google/cloud/hooks/translate.py index 51cb88f1bacef..6ddb220f3e789 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/translate.py +++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py @@ -29,6 +29,7 @@ from google.api_core.exceptions import GoogleAPICallError from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry from google.cloud.translate_v2 import Client from google.cloud.translate_v3 import TranslationServiceClient @@ -38,13 +39,31 @@ if TYPE_CHECKING: from google.api_core.operation import Operation - from google.api_core.retry import Retry + from google.cloud.translate_v3.services.translation_service import pagers from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) + from proto import Message + + +class WaitOperationNotDoneYetError(Exception): + """Wait operation not done yet error.""" + + pass + + +def _if_exc_is_wait_failed_error(exc: Exception): + return isinstance(exc, WaitOperationNotDoneYetError) + + +def _check_if_operation_done(operation: Operation): + if not operation.done(): + raise WaitOperationNotDoneYetError("Operation is not done yet.") class CloudTranslateHook(GoogleBaseHook): @@ -163,7 +182,42 @@ def get_client(self) -> TranslationServiceClient: return self._client @staticmethod - def wait_for_operation(operation: Operation, timeout: int | None = None): + def wait_for_operation_done( + *, + operation: Operation, + timeout: float | None = None, + initial: float = 3, + multiplier: float = 2, + maximum: float = 3600, + ) -> None: + """ + Wait for long-running operation to be done. + + Calls operation.done() until success or timeout exhaustion, following the back-off retry strategy. + See `google.api_core.retry.Retry`. 
+ It's intended use on `Operation` instances that have empty result + (:class `google.protobuf.empty_pb2.Empty`) by design. + Thus calling operation.result() for such operation triggers the exception + ``GoogleAPICallError("Unexpected state: Long-running operation had neither response nor error set.")`` + even though operation itself is totally fine. + """ + wait_op_for_done = Retry( + predicate=_if_exc_is_wait_failed_error, + initial=initial, + timeout=timeout, + multiplier=multiplier, + maximum=maximum, + )(_check_if_operation_done) + try: + wait_op_for_done(operation=operation) + except GoogleAPICallError: + if timeout: + timeout = int(timeout) + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + @staticmethod + def wait_for_operation_result(operation: Operation, timeout: int | None = None) -> Message: """Wait for long-lasting operation to complete.""" try: return operation.result(timeout=timeout) @@ -171,6 +225,11 @@ def wait_for_operation(operation: Operation, timeout: int | None = None): error = operation.exception(timeout=timeout) raise AirflowException(error) + @staticmethod + def extract_object_id(obj: dict) -> str: + """Return unique id of the object.""" + return obj["name"].rpartition("/")[-1] + def translate_text( self, *, @@ -208,12 +267,10 @@ def translate_text( If not specified, 'global' is used. Non-global location is required for requests using AutoML models or custom glossaries. - Models and glossaries must be within the same region (have the same location-id). :param model: Optional. The ``model`` type requested for this translation. If not provided, the default Google model (NMT) will be used. - The format depends on model type: - AutoML Translation models: @@ -308,8 +365,8 @@ def batch_translate_text( :param timeout: The timeout for this request. :param metadata: Strings which should be sent along with the request as metadata. - :returns: Operation object with the batch text translate results, - that are returned by batches as they are ready. + :return: Operation object with the batch text translate results, + that are returned by batches as they are ready. """ client = self.get_client() if location == "global": @@ -334,3 +391,174 @@ def batch_translate_text( metadata=metadata, ) return result + + def create_dataset( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault | None = DEFAULT, + ) -> Operation: + """ + Create the translation dataset. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the dataset to be created. 
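
        A minimal ``dataset`` payload, when passed as a dict, may look like the
        following (the field values are illustrative only; they mirror the sample
        dataset used in the example DAG of this change)::

            {
                "display_name": "my_translation_dataset",
                "source_language_code": "es",
                "target_language_code": "en",
            }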
+ """ + client = self.get_client() + parent = f"projects/{project_id or self.project_id}/locations/{location}" + return client.create_dataset( + request={"parent": parent, "dataset": dataset}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def get_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> automl_translation.Dataset: + """ + Retrieve the dataset for the given dataset_id. + + :param dataset_id: ID of translation dataset to be retrieved. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `automl_translation.Dataset` instance. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + return client.get_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def import_dataset_data( + self, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Import data into the translation dataset. + + :param dataset_id: ID of the translation dataset. + :param input_config: The desired input location and its domain specific semantics, if any. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located if None then + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object for the import data. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.import_data( + request={"dataset": name, "input_config": input_config}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def list_datasets( + self, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> pagers.ListDatasetsPager: + """ + List translation datasets in a project. + + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: ``pagers.ListDatasetsPager`` instance, iterable object to retrieve the datasets list. + """ + client = self.get_client() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_datasets( + request={"parent": parent}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def delete_dataset( + self, + dataset_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete the translation dataset and all of its contents. + + :param dataset_id: ID of dataset to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object with dataset deletion results, when finished. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.delete_dataset( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py b/providers/src/airflow/providers/google/cloud/links/translate.py index d8cbd18d00dad..0d1489ddcfc81 100644 --- a/providers/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/src/airflow/providers/google/cloud/links/translate.py @@ -45,6 +45,11 @@ TRANSLATION_TRANSLATE_TEXT_BATCH = BASE_LINK + "/storage/browser/{output_uri_prefix}?project={project_id}" +TRANSLATION_NATIVE_DATASET_LINK = ( + TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/sentences?project={project_id}" +) +TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}" + class TranslationLegacyDatasetLink(BaseGoogleLink): """ @@ -214,3 +219,54 @@ def persist( "output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config), }, ) + + +class TranslationNativeDatasetLink(BaseGoogleLink): + """ + Helper class for constructing Legacy Translation Dataset link. + + Legacy Datasets are created and managed by AutoML API. + """ + + name = "Translation Native Dataset" + key = "translation_naive_dataset" + format_str = TRANSLATION_NATIVE_DATASET_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationNativeDatasetLink.key, + value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, + ) + + +class TranslationDatasetsListLink(BaseGoogleLink): + """ + Helper class for constructing Translation Datasets List link. + + Both legacy and native datasets are available under this link. 
+ """ + + name = "Translation Dataset List" + key = "translation_dataset_list" + format_str = TRANSLATION_DATASET_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationDatasetsListLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/translate.py b/providers/src/airflow/providers/google/cloud/operators/translate.py index a0fa9243e01a4..d384e9b8efa9b 100644 --- a/providers/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/src/airflow/providers/google/cloud/operators/translate.py @@ -26,17 +26,23 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook, TranslateHook -from airflow.providers.google.cloud.links.translate import TranslateTextBatchLink +from airflow.providers.google.cloud.links.translate import ( + TranslateTextBatchLink, + TranslationDatasetsListLink, + TranslationNativeDatasetLink, +) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID if TYPE_CHECKING: from google.api_core.retry import Retry from google.cloud.translate_v3.types import ( + DatasetInputConfig, InputConfig, OutputConfig, TranslateTextGlossaryConfig, TransliterationConfig, + automl_translation, ) from airflow.utils.context import Context @@ -266,7 +272,7 @@ class TranslateTextBatchOperator(GoogleCloudBaseOperator): See https://cloud.google.com/translate/docs/advanced/batch-translation For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:TranslateTextBatchOperator`. + :ref:`howto/operator:TranslateTextBatchOperator`. :param project_id: Optional. The ID of the Google Cloud project that the service belongs to. If not specified the hook project_id will be used. @@ -381,6 +387,338 @@ def execute(self, context: Context) -> dict: project_id=self.project_id or hook.project_id, output_config=self.output_config, ) - hook.wait_for_operation(translate_operation) + hook.wait_for_operation_result(translate_operation) self.log.info("Translate text batch job finished") return {"batch_text_translate_results": self.output_config["gcs_destination"]} + + +class TranslateCreateDatasetOperator(GoogleCloudBaseOperator): + """ + Create a Google Cloud Translate dataset. + + Creates a `native` translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateCreateDatasetOperator`. + + :param dataset: The dataset to create. If a dict is provided, it must correspond to + the automl_translation.Dataset type. + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset: dict | automl_translation.Dataset, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault | None = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.dataset = dataset + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> str: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Dataset creation started %s...", self.dataset) + result_operation = hook.create_dataset( + dataset=self.dataset, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result = hook.wait_for_operation_result(result_operation) + result = type(result).to_dict(result) + dataset_id = hook.extract_object_id(result) + self.xcom_push(context, key="dataset_id", value=dataset_id) + self.log.info("Dataset creation complete. The dataset_id: %s.", dataset_id) + + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + ) + return result + + +class TranslateDatasetsListOperator(GoogleCloudBaseOperator): + """ + Get a list of native Google Cloud Translation datasets in a project. + + Get project's list of `native` translation datasets, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDatasetsListOperator`. + + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
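
    A minimal usage sketch (the ``task_id``, project and location values are
    illustrative only)::

        list_datasets = TranslateDatasetsListOperator(
            task_id="translate_v3_list_ds",
            project_id="my-project-id",
            location="us-central1",
        )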
+ """ + + template_fields: Sequence[str] = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationDatasetsListLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + project_id = self.project_id or hook.project_id + TranslationDatasetsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + self.log.info("Requesting datasets list") + results_pager = hook.list_datasets( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result_ids = [] + for ds_item in results_pager: + ds_data = type(ds_item).to_dict(ds_item) + ds_id = hook.extract_object_id(ds_data) + result_ids.append(ds_id) + + self.log.info("Fetching the datasets list complete.") + return result_ids + + +class TranslateImportDataOperator(GoogleCloudBaseOperator): + """ + Import data to the translation dataset. + + Loads data to the translation dataset, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateImportDataOperator`. + + :param dataset_id: The dataset_id of target native dataset to import data to. + :param input_config: The desired input location of translations language pairs file. If a dict provided, + must follow the structure of DatasetInputConfig. + If a dict is provided, it must be of the same form as the protobuf message InputConfig. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "input_config", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationNativeDatasetLink(),) + + def __init__( + self, + *, + dataset_id: str, + location: str, + input_config: dict | DatasetInputConfig, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.input_config = input_config + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Importing data to dataset...") + operation = hook.import_dataset_data( + dataset_id=self.dataset_id, + input_config=self.input_config, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + project_id = self.project_id or hook.project_id + TranslationNativeDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + project_id=project_id, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Importing data finished!") + + +class TranslateDeleteDatasetOperator(GoogleCloudBaseOperator): + """ + Delete translation dataset and all of its contents. + + Deletes the translation dataset and it's data, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDeleteDatasetOperator`. + + :param dataset_id: The dataset_id of target native dataset to be deleted. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = ( + "dataset_id", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + dataset_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dataset_id = dataset_id + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting the dataset %s...", self.dataset_id) + operation = hook.delete_dataset( + dataset_id=self.dataset_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Dataset deletion complete!") diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index e1d2759124f55..88f1021d5969d 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -1292,6 +1292,8 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink - airflow.providers.google.cloud.links.translate.TranslateTextBatchLink + - airflow.providers.google.cloud.links.translate.TranslationNativeDatasetLink + - airflow.providers.google.cloud.links.translate.TranslationDatasetsListLink secrets-backends: diff --git a/providers/tests/google/cloud/operators/test_translate.py b/providers/tests/google/cloud/operators/test_translate.py index 79f65395369b8..45af2dae92890 100644 --- a/providers/tests/google/cloud/operators/test_translate.py +++ b/providers/tests/google/cloud/operators/test_translate.py @@ -19,15 +19,27 @@ from unittest import mock +from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.translate_v3.types import automl_translation + +from airflow.providers.google.cloud.hooks.translate import TranslateHook from airflow.providers.google.cloud.operators.translate import ( CloudTranslateTextOperator, + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, TranslateTextBatchOperator, TranslateTextOperator, ) +from providers.tests.system.google.cloud.tasks.example_tasks import LOCATION + GCP_CONN_ID = "google_cloud_default" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] PROJECT_ID = "test-project-id" +DATASET_ID = "sample_ds_id" +TIMEOUT_VALUE = 30 class TestCloudTranslate: @@ -97,7 +109,7 @@ def test_minimal_green_path(self, mock_hook): target_language_code="en", gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, model=None, ) @@ -117,7 +129,7 @@ def test_minimal_green_path(self, mock_hook): model=None, transliteration_config=None, glossary_config=None, - timeout=30, + timeout=TIMEOUT_VALUE, retry=None, metadata=(), ) @@ -185,3 +197,192 @@ def 
test_minimal_green_path(self, mock_hook, mock_link_persist): project_id=PROJECT_ID, output_config=OUTPUT_CONFIG, ) + + +class TestTranslateDatasetCreate: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator.xcom_push") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + DS_CREATION_RESULT_SAMPLE = { + "display_name": "", + "example_count": 0, + "name": f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}", + "source_language_code": "", + "target_language_code": "", + "test_example_count": 0, + "train_example_count": 0, + "validate_example_count": 0, + } + sample_operation = mock.MagicMock() + sample_operation.result.return_value = automl_translation.Dataset(DS_CREATION_RESULT_SAMPLE) + + mock_hook.return_value.create_dataset.return_value = sample_operation + mock_hook.return_value.wait_for_operation_result.side_effect = lambda operation: operation.result() + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + DATASET_DATA = { + "display_name": "sample ds name", + "source_language_code": "es", + "target_language_code": "uk", + } + op = TranslateCreateDatasetOperator( + task_id="task_id", + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=None, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.create_dataset.assert_called_once_with( + dataset=DATASET_DATA, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=None, + metadata=(), + ) + mock_xcom_push.assert_called_once_with(context, key="dataset_id", value=DATASET_ID) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == DS_CREATION_RESULT_SAMPLE + + +class TestTranslateListDatasets: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationDatasetsListLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + DS_ID_1 = "sample_ds_1" + DS_ID_2 = "sample_ds_2" + dataset_result_1 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_1}", + display_name="ds1_display_name", + ) + ) + dataset_result_2 = automl_translation.Dataset( + dict( + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DS_ID_2}", + display_name="ds1_display_name", + ) + ) + mock_hook.return_value.list_datasets.return_value = [dataset_result_1, dataset_result_2] + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + op = TranslateDatasetsListOperator( + task_id="task_id", + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + 
mock_hook.return_value.list_datasets.assert_called_once_with( + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + project_id=PROJECT_ID, + ) + assert result == [DS_ID_1, DS_ID_2] + + +class TestTranslateImportData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslationNativeDatasetLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + INPUT_CONFIG = { + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "import data gcs path"}}] + } + mock_hook.return_value.import_dataset_data.return_value = mock.MagicMock() + op = TranslateImportDataOperator( + task_id="task_id", + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.import_dataset_data.assert_called_once_with( + dataset_id=DATASET_ID, + input_config=INPUT_CONFIG, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + mock_link_persist.assert_called_once_with( + context=context, + dataset_id=DATASET_ID, + task_instance=op, + project_id=PROJECT_ID, + ) + + +class TestTranslateDeleteData: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook): + m_delete_method_result = mock.MagicMock() + mock_hook.return_value.delete_dataset.return_value = m_delete_method_result + + wait_for_done = mock_hook.return_value.wait_for_operation_done + + op = TranslateDeleteDatasetOperator( + task_id="task_id", + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.delete_dataset.assert_called_once_with( + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + wait_for_done.assert_called_once_with(operation=m_delete_method_result, timeout=TIMEOUT_VALUE) diff --git a/providers/tests/system/google/cloud/translate/example_translate_dataset.py b/providers/tests/system/google/cloud/translate/example_translate_dataset.py new file mode 100644 index 0000000000000..3ad732862449d --- /dev/null +++ b/providers/tests/system/google/cloud/translate/example_translate_dataset.py @@ -0,0 +1,153 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that translates text in Google Cloud Translate using V3 API version +service in the Google Cloud. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.operators.translate import ( + TranslateCreateDatasetOperator, + TranslateDatasetsListOperator, + TranslateDeleteDatasetOperator, + TranslateImportDataOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "gcp_translate_automl_native_dataset" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +REGION = "us-central1" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" + +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") + +DATA_FILE_NAME = "import_en-es.tsv" + +RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}" +COPY_DATA_PATH = f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}" + +DST_PATH = f"translate/import/{DATA_FILE_NAME}" + +DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}" + + +DATASET = { + "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}", + "source_language_code": "es", + "target_language_code": "en", +} + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2024, 11, 1), + catchup=False, + tags=[ + "example", + "translate_dataset", + ], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + copy_dataset_source_tsv = GCSToGCSOperator( + task_id="copy_dataset_file", + source_bucket=RESOURCE_DATA_BUCKET, + source_object=RESOURCE_PATH, + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object=DST_PATH, + ) + + # [START howto_operator_translate_automl_create_dataset] + create_dataset_op = TranslateCreateDatasetOperator( + task_id="translate_v3_ds_create", + dataset=DATASET, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_create_dataset] + + # [START howto_operator_translate_automl_import_data] + import_ds_data_op = TranslateImportDataOperator( + task_id="translate_v3_ds_import_data", + dataset_id=create_dataset_op.output["dataset_id"], + input_config={ + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": DATASET_DATA_PATH}}] + }, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_import_data] + + # [START howto_operator_translate_automl_list_datasets] + list_datasets_op = TranslateDatasetsListOperator( + task_id="translate_v3_list_ds", + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_list_datasets] + + # [START howto_operator_translate_automl_delete_dataset] + delete_ds_op = TranslateDeleteDatasetOperator( + 
task_id="translate_v3_ds_delete", + dataset_id=create_dataset_op.output["dataset_id"], + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_delete_dataset] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + [create_bucket >> copy_dataset_source_tsv] + # TEST BODY + >> create_dataset_op + >> import_ds_data_op + >> list_datasets_op + >> delete_ds_op + # TEST TEARDOWN + >> delete_bucket + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)