diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst b/docs/apache-airflow-providers-google/operators/cloud/translate.rst index 579236cb0883c..6bcc32ec669ce 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst @@ -18,7 +18,7 @@ Google Cloud Translate Operators --------------------------------- +======================================= Prerequisite Tasks ^^^^^^^^^^^^^^^^^^ @@ -41,19 +41,19 @@ Using the operator Basic usage of the operator: .. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate.py - :language: python - :dedent: 4 - :start-after: [START howto_operator_translate_text] - :end-before: [END howto_operator_translate_text] + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_text] + :end-before: [END howto_operator_translate_text] The result of translation is available as dictionary or array of dictionaries accessible via the usual XCom mechanisms of Airflow: .. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate.py - :language: python - :dedent: 4 - :start-after: [START howto_operator_translate_access] - :end-before: [END howto_operator_translate_access] + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_access] + :end-before: [END howto_operator_translate_access] Templating @@ -65,10 +65,47 @@ Templating :start-after: [START translate_template_fields] :end-before: [END translate_template_fields] +.. _howto/operator:TranslateTextOperator: + +TranslateTextOperator +^^^^^^^^^^^^^^^^^^^^^ + +Translate an array of one or more text (or html) items. +Intended to use for moderate amount of text data, for large volumes please use the +:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator` + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_text.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_text_advanced] + :end-before: [END howto_operator_translate_text_advanced] + + +.. _howto/operator:TranslateTextBatchOperator: + +TranslateTextBatchOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^ +Translate large amount of text data into up to 10 target languages in a single run. +List of files and other options provided by input configuration. + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator` + + More information -"""""""""""""""" +"""""""""""""""""" +See: +Base (V2) `Google Cloud Translate documentation `_. +Advanced (V3) `Google Cloud Translate (Advanced) documentation `_. -See `Google Cloud Translate documentation `_. Reference ^^^^^^^^^ diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index b2df194de0dfd..20e10c44a12c6 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1702,6 +1702,7 @@ traceback tracebacks tracemalloc TrainingPipeline +TranslationServiceClient travis triage triaging diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py b/providers/src/airflow/providers/google/cloud/hooks/translate.py index 04a309678a61b..51cb88f1bacef 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/translate.py +++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py @@ -19,12 +19,32 @@ from __future__ import annotations -from typing import Sequence +from typing import ( + TYPE_CHECKING, + MutableMapping, + MutableSequence, + Sequence, + cast, +) +from google.api_core.exceptions import GoogleAPICallError +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from google.cloud.translate_v2 import Client +from google.cloud.translate_v3 import TranslationServiceClient +from airflow.exceptions import AirflowException from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook + +if TYPE_CHECKING: + from google.api_core.operation import Operation + from google.api_core.retry import Retry + from google.cloud.translate_v3.types import ( + InputConfig, + OutputConfig, + TranslateTextGlossaryConfig, + TransliterationConfig, + ) class CloudTranslateHook(GoogleBaseHook): @@ -81,7 +101,7 @@ def translate( :param source_language: (Optional) The language of the text to be translated. :param model: (Optional) The model used to translate the text, such - as ``'base'`` or ``'nmt'``. + as ``'base'`` or ``'NMT'``. :returns: A list of dictionaries for each queried value. Each dictionary typically contains three keys (though not all will be present in all cases) @@ -102,7 +122,6 @@ def translate( values and translations differ. """ client = self.get_conn() - return client.translate( values=values, target_language=target_language, @@ -110,3 +129,208 @@ def translate( source_language=source_language, model=model, ) + + +class TranslateHook(GoogleBaseHook): + """ + Hook for Google Cloud translation (Advanced) using client version V3. + + See related docs https://cloud.google.com/translate/docs/editions#advanced. + """ + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + ) + self._client: TranslationServiceClient | None = None + + def get_client(self) -> TranslationServiceClient: + """ + Retrieve TranslationService client. + + :return: Google Cloud Translation Service client object. + """ + if self._client is None: + self._client = TranslationServiceClient( + credentials=self.get_credentials(), client_info=CLIENT_INFO + ) + return self._client + + @staticmethod + def wait_for_operation(operation: Operation, timeout: int | None = None): + """Wait for long-lasting operation to complete.""" + try: + return operation.result(timeout=timeout) + except GoogleAPICallError: + error = operation.exception(timeout=timeout) + raise AirflowException(error) + + def translate_text( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + contents: Sequence[str], + target_language_code: str, + source_language_code: str | None = None, + mime_type: str | None = None, + location: str | None = None, + model: str | None = None, + transliteration_config: TransliterationConfig | None = None, + glossary_config: TranslateTextGlossaryConfig | None = None, + labels: str | None = None, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault | None = DEFAULT, + ) -> dict: + """ + Translate text content provided. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param contents: Required. The content of the input in string + format. Max length 1024 items with 30_000 codepoints recommended. + :param mime_type: Optional. The format of the source text, If left + blank, the MIME type defaults to "text/html". + :param source_language_code: Optional. The ISO-639 language code of the + input text if known. If the source language + isn't specified, the API attempts to identify + the source language automatically and returns + the source language within the response. + :param target_language_code: Required. The ISO-639 language code to use + for translation of the input text + :param location: Optional. Project or location to make a call. Must refer to + a caller's project. + If not specified, 'global' is used. + Non-global location is required for requests using AutoML + models or custom glossaries. + + Models and glossaries must be within the same region (have + the same location-id). + :param model: Optional. The ``model`` type requested for this translation. + If not provided, the default Google model (NMT) will be used. + + The format depends on model type: + + - AutoML Translation models: + ``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}`` + - General (built-in) models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt`` + - Translation LLM models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/translation-llm`` + + For global (no region) requests, use ``location-id`` ``global``. + For example, ``projects/{project-number-or-id}/locations/global/models/general/nmt``. + :param glossary_config: Optional. Glossary to be applied. The glossary must be + within the same region (have the same location-id) as the + model. + :param transliteration_config: Optional. Transliteration to be applied. + :param labels: Optional. The labels with user-defined + metadata for the request. + See https://cloud.google.com/translate/docs/advanced/labels for more information. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + + :return: Translate text result from the API response. + """ + client = self.get_client() + location_id = "global" if not location else location + parent = f"projects/{project_id or self.project_id}/locations/{location_id}" + + result = client.translate_text( + request={ + "parent": parent, + "source_language_code": source_language_code, + "target_language_code": target_language_code, + "contents": contents, + "mime_type": mime_type, + "glossary_config": glossary_config, + "transliteration_config": transliteration_config, + "model": model, + "labels": labels, + }, + timeout=timeout, + retry=retry, + metadata=metadata, + ) + return cast(dict, type(result).to_dict(result)) + + def batch_translate_text( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + source_language_code: str, + target_language_codes: MutableSequence[str], + input_configs: MutableSequence[InputConfig | dict], + output_config: OutputConfig | dict, + models: str | None = None, + glossaries: MutableMapping[str, TranslateTextGlossaryConfig] | None = None, + labels: MutableMapping[str, str] | None = None, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + retry: Retry | _MethodDefault | None = DEFAULT, + ) -> Operation: + """ + Translate large volumes of text data. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Optional. Project or location to make a call. Must refer to + a caller's project. Must be non-global. + :param source_language_code: Required. Source language code. + :param target_language_codes: Required. Specify up to 10 language codes here. + :param models: Optional. The models to use for translation. Map's key is + target language code. Map's value is model name. Value can + be a built-in general model, or an AutoML Translation model. + The value format depends on model type: + + - AutoML Translation models: + ``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}`` + - General (built-in) models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt`` + + If the map is empty or a specific model is not requested for + a language pair, then the default Google model (NMT) is used. + :param input_configs: Required. Input configurations. + The total number of files matched should be <= 100. The total content size should be <= 100M + Unicode codepoints. The files must use UTF-8 encoding. + :param output_config: Required. Output configuration. + :param glossaries: Optional. Glossaries to be applied for + translation. It's keyed by target language code. + :param labels: Optional. The labels with user-defined metadata for the request. + See https://cloud.google.com/translate/docs/advanced/labels for more information. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + + :returns: Operation object with the batch text translate results, + that are returned by batches as they are ready. + """ + client = self.get_client() + if location == "global": + raise AirflowException( + "Global location is not allowed for the batch text translation, " + "please provide the correct value!" + ) + parent = f"projects/{project_id or self.project_id}/locations/{location}" + result = client.batch_translate_text( + request={ + "parent": parent, + "source_language_code": source_language_code, + "target_language_codes": target_language_codes, + "input_configs": input_configs, + "output_config": output_config, + "glossaries": glossaries, + "models": models, + "labels": labels, + }, + timeout=timeout, + retry=retry, + metadata=metadata, + ) + return result diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py b/providers/src/airflow/providers/google/cloud/links/translate.py index b770570fe443f..d8cbd18d00dad 100644 --- a/providers/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/src/airflow/providers/google/cloud/links/translate.py @@ -43,6 +43,8 @@ + "/locations/{location}/datasets/{dataset_id}/predict;modelId={model_id}?project={project_id}" ) +TRANSLATION_TRANSLATE_TEXT_BATCH = BASE_LINK + "/storage/browser/{output_uri_prefix}?project={project_id}" + class TranslationLegacyDatasetLink(BaseGoogleLink): """ @@ -179,3 +181,36 @@ def persist( "project_id": project_id, }, ) + + +class TranslateTextBatchLink(BaseGoogleLink): + """ + Helper class for constructing Translation results for the text batch translate. + + Provides link to output results. + + """ + + name = "Text Translate Batch" + key = "translate_text_batch" + format_str = TRANSLATION_TRANSLATE_TEXT_BATCH + + @staticmethod + def extract_output_uri_prefix(output_config): + return output_config["gcs_destination"]["output_uri_prefix"].rpartition("gs://")[-1] + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + output_config: dict, + ): + task_instance.xcom_push( + context, + key=TranslateTextBatchLink.key, + value={ + "project_id": project_id, + "output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config), + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/translate.py b/providers/src/airflow/providers/google/cloud/operators/translate.py index 2a17ebcafa91f..a0fa9243e01a4 100644 --- a/providers/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/src/airflow/providers/google/cloud/operators/translate.py @@ -19,13 +19,26 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, MutableMapping, MutableSequence, Sequence + +from google.api_core.exceptions import GoogleAPICallError +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault from airflow.exceptions import AirflowException -from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook +from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook, TranslateHook +from airflow.providers.google.cloud.links.translate import TranslateTextBatchLink from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID if TYPE_CHECKING: + from google.api_core.retry import Retry + from google.cloud.translate_v3.types import ( + InputConfig, + OutputConfig, + TranslateTextGlossaryConfig, + TransliterationConfig, + ) + from airflow.utils.context import Context @@ -42,43 +55,27 @@ class CloudTranslateTextOperator(GoogleCloudBaseOperator): Execute method returns str or list. This is a list of dictionaries for each queried value. Each - dictionary typically contains three keys (though not - all will be present in all cases). + dictionary typically contains three keys (though not all will be present in all cases): - * ``detectedSourceLanguage``: The detected language (as an - ISO 639-1 language code) of the text. - * ``translatedText``: The translation of the text into the - target language. + * ``detectedSourceLanguage``: The detected language (as an ISO 639-1 language code) of the text. + * ``translatedText``: The translation of the text into the target language. * ``input``: The corresponding input value. * ``model``: The model used to translate the text. If only a single value is passed, then only a single - dictionary is set as XCom return value. + dictionary is set as the XCom return value. :param values: String or list of strings to translate. - - :param target_language: The language to translate results into. This - is required by the API and defaults to - the target language of the current instance. - - :param format_: (Optional) One of ``text`` or ``html``, to specify - if the input text is plain text or HTML. - - :param source_language: (Optional) The language of the text to - be translated. - - :param model: (Optional) The model used to translate the text, such - as ``'base'`` or ``'nmt'``. - - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - + :param target_language: The language to translate results into. This is required by the API. + :param format_: (Optional) One of ``text`` or ``html``, to specify if the input text is plain text or HTML. + :param source_language: (Optional) The language of the text to be translated. + :param model: (Optional) The model used to translate the text, such as ``'base'`` or ``'nmt'``. + :param impersonation_chain: Optional service account to impersonate using short-term credentials, or + chained list of accounts required to get the access_token of the last account in the list, which + will be impersonated in the request. If set as a string, the account must grant the originating + account the Service Account Token Creator IAM role. If set as a sequence, the identities from + the list must grant Service Account Token Creator IAM role to the directly preceding identity, + with the first account from the list granting this role to the originating account (templated). """ # [START translate_template_fields] @@ -133,3 +130,257 @@ def execute(self, context: Context) -> dict: self.log.error("An error has been thrown from translate method:") self.log.error(e) raise AirflowException(e) + + +class TranslateTextOperator(GoogleCloudBaseOperator): + """ + Translate text content of moderate amount, for larger volumes of text please use the TranslateTextBatchOperator. + + Wraps the Google cloud Translate Text (Advanced) functionality. + See https://cloud.google.com/translate/docs/advanced/translating-text-v3 + + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateTextOperator`. + + :param project_id: Optional. The ID of the Google Cloud project that the + service belongs to. + :param location: optional. The ID of the Google Cloud location that the + service belongs to. if not specified, 'global' is used. + Non-global location is required for requests using AutoML models or custom glossaries. + :param contents: Required. The sequence of content strings to be translated. + Limited to 1024 items with 30_000 codepoints total recommended. + :param mime_type: Optional. The format of the source text, If left blank, + the MIME type defaults to "text/html". + :param source_language_code: Optional. The ISO-639 language code of the + input text if known. If not specified, attempted to recognize automatically. + :param target_language_code: Required. The ISO-639 language code to use + for translation of the input text. + :param model: Optional. The ``model`` type requested for this translation. + If not provided, the default Google model (NMT) will be used. + The format depends on model type: + + - AutoML Translation models: + ``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}`` + - General (built-in) models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt`` + - Translation LLM models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/translation-llm`` + + For global (non-region) requests, use 'global' ``location-id``. + :param glossary_config: Optional. Glossary to be applied. + :param transliteration_config: Optional. Transliteration to be applied. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "contents", + "target_language_code", + "mime_type", + "source_language_code", + "model", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + contents: Sequence[str], + source_language_code: str | None = None, + target_language_code: str, + mime_type: str | None = None, + location: str | None = None, + project_id: str = PROVIDE_PROJECT_ID, + model: str | None = None, + transliteration_config: TransliterationConfig | None = None, + glossary_config: TranslateTextGlossaryConfig | None = None, + labels: str | None = None, + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault | None = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.contents = contents + self.source_language_code = source_language_code + self.target_language_code = target_language_code + self.mime_type = mime_type + self.location = location + self.labels = labels + self.model = model + self.transliteration_config = transliteration_config + self.glossary_config = glossary_config + self.metadate = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + try: + self.log.info("Starting the text translation run") + translation_result = hook.translate_text( + contents=self.contents, + source_language_code=self.source_language_code, + target_language_code=self.target_language_code, + mime_type=self.mime_type, + location=self.location, + labels=self.labels, + model=self.model, + transliteration_config=self.transliteration_config, + glossary_config=self.glossary_config, + timeout=self.timeout, + retry=self.retry, + metadata=self.metadate, + ) + self.log.info("Text translation run complete") + return translation_result + except GoogleAPICallError as e: + self.log.error("An error occurred executing translate_text method: \n%s", e) + raise AirflowException(e) + + +class TranslateTextBatchOperator(GoogleCloudBaseOperator): + """ + Translate large volumes of text content, by the inputs provided. + + Wraps the Google cloud Translate Text (Advanced) functionality. + See https://cloud.google.com/translate/docs/advanced/batch-translation + + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateTextBatchOperator`. + + :param project_id: Optional. The ID of the Google Cloud project that the + service belongs to. If not specified the hook project_id will be used. + :param location: required. The ID of the Google Cloud location, (non-global) that the + service belongs to. + :param source_language_code: Required. Source language code. + :param target_language_codes: Required. Up to 10 language codes allowed here. + :param input_configs: Required. Input configurations. + The total number of files matched should be <=100. The total content size should be <= 100M Unicode codepoints. + The files must use UTF-8 encoding. + :param models: Optional. The models to use for translation. Map's key is + target language code. Map's value is model name. Value can + be a built-in general model, or an AutoML Translation model. + The value format depends on model type: + + - AutoML Translation models: + ``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}`` + - General (built-in) models: + ``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt`` + + If the map is empty or a specific model is not requested for + a language pair, then the default Google model (NMT) is used. + :param output_config: Required. Output configuration. + :param glossaries: Optional. Glossaries to be applied for translation. It's keyed by target language code. + :param labels: Optional. The labels with user-defined metadata. + See https://cloud.google.com/translate/docs/advanced/labels for more information. + + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + operator_extra_links = (TranslateTextBatchLink(),) + + template_fields: Sequence[str] = ( + "input_configs", + "target_language_codes", + "source_language_code", + "models", + "glossaries", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + target_language_codes: MutableSequence[str], + source_language_code: str, + input_configs: MutableSequence[InputConfig | dict], + output_config: OutputConfig | dict, + models: str | None = None, + glossaries: MutableMapping[str, TranslateTextGlossaryConfig] | None = None, + labels: MutableMapping[str, str] | None = None, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault | None = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.target_language_codes = target_language_codes + self.source_language_code = source_language_code + self.input_configs = input_configs + self.output_config = output_config + self.models = models + self.glossaries = glossaries + self.labels = labels + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> dict: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + translate_operation = hook.batch_translate_text( + project_id=self.project_id, + location=self.location, + target_language_codes=self.target_language_codes, + source_language_code=self.source_language_code, + input_configs=self.input_configs, + output_config=self.output_config, + models=self.models, + glossaries=self.glossaries, + labels=self.labels, + metadata=self.metadata, + timeout=self.timeout, + retry=self.retry, + ) + self.log.info("Translate text batch job started.") + TranslateTextBatchLink.persist( + context=context, + task_instance=self, + project_id=self.project_id or hook.project_id, + output_config=self.output_config, + ) + hook.wait_for_operation(translate_operation) + self.log.info("Translate text batch job finished") + return {"batch_text_translate_results": self.output_config["gcs_destination"]} diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 6c3350a7ba4f0..e1d2759124f55 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -1291,6 +1291,7 @@ extra-links: - airflow.providers.google.cloud.links.translate.TranslationLegacyModelLink - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink + - airflow.providers.google.cloud.links.translate.TranslateTextBatchLink secrets-backends: diff --git a/providers/tests/google/cloud/hooks/test_translate.py b/providers/tests/google/cloud/hooks/test_translate.py index 42429bc58ddae..f4e7f14e0e748 100644 --- a/providers/tests/google/cloud/hooks/test_translate.py +++ b/providers/tests/google/cloud/hooks/test_translate.py @@ -19,7 +19,9 @@ from unittest import mock -from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook +from google.cloud.translate_v3.types import TranslateTextResponse + +from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook, TranslateHook from airflow.providers.google.common.consts import CLIENT_INFO from providers.tests.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id @@ -75,3 +77,119 @@ def test_translate_called(self, get_conn): source_language=None, model="base", ) + + +class TestTranslateHook: + def setup_method(self): + with mock.patch( + "airflow.providers.google.cloud.hooks.translate.TranslateHook.__init__", + new=mock_base_gcp_hook_default_project_id, + ): + self.hook = TranslateHook(gcp_conn_id="test") + + @mock.patch("airflow.providers.google.cloud.hooks.translate.TranslateHook.get_credentials") + @mock.patch("airflow.providers.google.cloud.hooks.translate.TranslationServiceClient") + def test_translate_client_creation(self, mock_client, mock_get_creds): + result = self.hook.get_client() + mock_client.assert_called_once_with(credentials=mock_get_creds.return_value, client_info=CLIENT_INFO) + assert mock_client.return_value == result + assert self.hook._client == result + + @mock.patch("airflow.providers.google.cloud.hooks.translate.TranslateHook.get_client") + def test_translate_text_method(self, get_client): + translation_result_data = { + "translations": [ + {"translated_text": "Hello World!", "model": "", "detected_language_code": ""}, + { + "translated_text": "Can you get me a cup of coffee, please?", + "model": "", + "detected_language_code": "", + }, + ], + "glossary_translations": [], + } + data_to_translate = ["Ciao mondo!", "Mi puoi prendere una tazza di caffè, per favore?"] + translate_client = get_client.return_value + translate_text_client_method = translate_client.translate_text + translate_text_client_method.return_value = TranslateTextResponse(translation_result_data) + + input_translation_args = dict( + project_id=PROJECT_ID_TEST, + contents=data_to_translate, + source_language_code="it", + target_language_code="en", + mime_type="text/plain", + location="global", + glossary_config=None, + transliteration_config=None, + model=None, + labels=None, + metadata=(), + timeout=30, + retry=None, + ) + result = self.hook.translate_text(**input_translation_args) + assert result == translation_result_data + + expected_call_args = { + "request": { + "parent": f"projects/{PROJECT_ID_TEST}/locations/global", + "contents": data_to_translate, + "source_language_code": "it", + "target_language_code": "en", + "mime_type": "text/plain", + "glossary_config": None, + "transliteration_config": None, + "model": None, + "labels": None, + }, + "retry": None, + "metadata": (), + "timeout": 30, + } + translate_text_client_method.assert_called_once_with(**expected_call_args) + + @mock.patch("airflow.providers.google.cloud.hooks.translate.TranslateHook.get_client") + def test_batch_translate_text_method(self, get_client): + sample_method_result = "batch_translate_api_call_result_obj" + translate_client = get_client.return_value + translate_text_client_method = translate_client.batch_translate_text + translate_text_client_method.return_value = sample_method_result + BATCH_TRANSLATE_INPUT = { + "gcs_source": {"input_uri": "input_source_uri"}, + "mime_type": "text/plain", + } + GCS_OUTPUT_DST = {"gcs_destination": {"output_uri_prefix": "translate_output_uri_prefix"}} + LOCATION = "us-central1" + input_translation_args = dict( + project_id=PROJECT_ID_TEST, + source_language_code="de", + target_language_codes=["en", "uk"], + location=LOCATION, + input_configs=[BATCH_TRANSLATE_INPUT], + output_config=GCS_OUTPUT_DST, + glossaries=None, + models=None, + labels=None, + metadata=(), + timeout=30, + retry=None, + ) + result = self.hook.batch_translate_text(**input_translation_args) + expected_call_args = { + "request": { + "parent": f"projects/{PROJECT_ID_TEST}/locations/{LOCATION}", + "source_language_code": "de", + "target_language_codes": ["en", "uk"], + "input_configs": [BATCH_TRANSLATE_INPUT], + "output_config": GCS_OUTPUT_DST, + "glossaries": None, + "models": None, + "labels": None, + }, + "retry": None, + "metadata": (), + "timeout": 30, + } + translate_text_client_method.assert_called_once_with(**expected_call_args) + assert result == sample_method_result diff --git a/providers/tests/google/cloud/operators/test_translate.py b/providers/tests/google/cloud/operators/test_translate.py index 7f2fd7983d5b8..79f65395369b8 100644 --- a/providers/tests/google/cloud/operators/test_translate.py +++ b/providers/tests/google/cloud/operators/test_translate.py @@ -19,10 +19,15 @@ from unittest import mock -from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator +from airflow.providers.google.cloud.operators.translate import ( + CloudTranslateTextOperator, + TranslateTextBatchOperator, + TranslateTextOperator, +) GCP_CONN_ID = "google_cloud_default" IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] +PROJECT_ID = "test-project-id" class TestCloudTranslate: @@ -67,3 +72,116 @@ def test_minimal_green_path(self, mock_hook): "input": "zażółć gęślą jaźń", } ] == return_value + + +class TestTranslateText: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook): + translation_result_data = { + "translations": [ + {"translated_text": "Hello World!", "model": "", "detected_language_code": ""}, + { + "translated_text": "Can you get me a cup of coffee, please?", + "model": "", + "detected_language_code": "", + }, + ], + "glossary_translations": [], + } + mock_hook.return_value.translate_text.return_value = translation_result_data + data_to_translate = ["Ciao mondo!", "Mi puoi prendere una tazza di caffè, per favore?"] + op = TranslateTextOperator( + task_id="task_id", + contents=data_to_translate, + source_language_code="it", + target_language_code="en", + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=30, + retry=None, + model=None, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.translate_text.assert_called_once_with( + contents=data_to_translate, + source_language_code="it", + target_language_code="en", + mime_type=None, + location=None, + labels=None, + model=None, + transliteration_config=None, + glossary_config=None, + timeout=30, + retry=None, + metadata=(), + ) + assert translation_result_data == result + + +class TestTranslateTextBatchOperator: + @mock.patch("airflow.providers.google.cloud.links.translate.TranslateTextBatchLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + input_config_item = { + "gcs_source": {"input_uri": "gs://source_bucket_uri/sample_data_src_lang.txt"}, + "mime_type": "text/plain", + } + SRC_LANG_CODE = "src_lang_code" + TARGET_LANG_CODES = ["target_lang_code1", "target_lang_code2"] + LOCATION = "location-id" + TIMEOUT = 30 + INPUT_CONFIGS = [input_config_item] + OUTPUT_CONFIG = {"gcs_destination": {"output_uri_prefix": "gs://source_bucket_uri/output/"}} + batch_translation_results_data = {"batch_text_translate_results": OUTPUT_CONFIG["gcs_destination"]} + mock_hook.return_value.batch_translate_text.return_value = batch_translation_results_data + + op = TranslateTextBatchOperator( + task_id="task_id_test", + project_id=PROJECT_ID, + source_language_code=SRC_LANG_CODE, + target_language_codes=TARGET_LANG_CODES, + location=LOCATION, + models=None, + glossaries=None, + input_configs=INPUT_CONFIGS, + output_config=OUTPUT_CONFIG, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT, + retry=None, + ) + context = {"ti": mock.MagicMock()} + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + + mock_hook.return_value.batch_translate_text.assert_called_once_with( + project_id=PROJECT_ID, + source_language_code=SRC_LANG_CODE, + target_language_codes=TARGET_LANG_CODES, + location=LOCATION, + input_configs=INPUT_CONFIGS, + output_config=OUTPUT_CONFIG, + timeout=TIMEOUT, + models=None, + glossaries=None, + labels=None, + retry=None, + metadata=(), + ) + assert batch_translation_results_data == result + + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + project_id=PROJECT_ID, + output_config=OUTPUT_CONFIG, + ) diff --git a/providers/tests/system/google/cloud/translate/example_translate_text.py b/providers/tests/system/google/cloud/translate/example_translate_text.py new file mode 100644 index 0000000000000..181d36bfe5d8e --- /dev/null +++ b/providers/tests/system/google/cloud/translate/example_translate_text.py @@ -0,0 +1,112 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that translates text in Google Cloud Translate using V3 API version +service in the Google Cloud. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.operators.translate import ( + TranslateTextBatchOperator, + TranslateTextOperator, +) +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "gcp_translate_text" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +REGION = "us-central1" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" +BATCH_TRANSLATE_SAMPLE_URI = ( + f"gs://{RESOURCE_DATA_BUCKET}/translate/V3/text_batch/inputs/translate_sample_de_1.txt" +) +BATCH_TRANSLATE_INPUT = { + "gcs_source": {"input_uri": BATCH_TRANSLATE_SAMPLE_URI}, + "mime_type": "text/plain", +} +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") +GCS_OUTPUT_DST = { + "gcs_destination": {"output_uri_prefix": f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/translate_output/"} +} + + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2024, 11, 1), + catchup=False, + tags=["example", "translate_text", "batch_translate_text"], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + # [START howto_operator_translate_text_advanced] + translate_text = TranslateTextOperator( + task_id="translate_v3_op", + contents=["Ciao mondo!", "Mi puoi prendere una tazza di caffè, per favore?"], + source_language_code="it", + target_language_code="en", + ) + # [END howto_operator_translate_text_advanced] + + # [START howto_operator_batch_translate_text] + batch_text_translate = TranslateTextBatchOperator( + task_id="batch_translate_v3_op", + source_language_code="de", + target_language_codes=["en"], # Up to 10 language codes per run + location="us-central1", + input_configs=[BATCH_TRANSLATE_INPUT], + output_config=GCS_OUTPUT_DST, + ) + # [END howto_operator_batch_translate_text] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + create_bucket + # TEST BODY + >> [translate_text, batch_text_translate] + # TEST TEARDOWN + >> delete_bucket + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)