Skip to content

Commit

Permalink
Introduce new gcp TranslateText and TranslateTextBatch operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleg Kachur committed Nov 6, 2024
1 parent bf00235 commit a985a0d
Show file tree
Hide file tree
Showing 8 changed files with 937 additions and 10 deletions.
238 changes: 235 additions & 3 deletions airflow/providers/google/cloud/hooks/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,38 @@

from __future__ import annotations

from typing import Sequence
from google.api_core.operation import Operation

from google.api_core.exceptions import GoogleAPICallError

from typing import (
TYPE_CHECKING,
MutableMapping,
MutableSequence,
Sequence,
cast,
)

from google.cloud.translate_v2 import Client

from google.cloud.translate_v3 import TranslationServiceClient

from google.cloud.translate_v3.types import (
TransliterationConfig,
TranslateTextGlossaryConfig,
InputConfig,
OutputConfig
)

from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault

from airflow.exceptions import AirflowException
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

if TYPE_CHECKING:
from google.api_core.retry import Retry


class CloudTranslateHook(GoogleBaseHook):
Expand Down Expand Up @@ -106,11 +132,217 @@ def translate(
values and translations differ.
"""
client = self.get_conn()

return client.translate(
values=values,
target_language=target_language,
format_=format_,
source_language=source_language,
model=model,
)


class TranslateHook(GoogleBaseHook):
"""
Hook for Google Cloud translation (Advanced) using client version V3
(https://cloud.google.com/translate/docs/editions#advanced).
"""

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:

super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self._client: TranslationServiceClient | None = None

def get_client(self) -> TranslationServiceClient:
"""
Retrieve TranslationService client.
:return: Google Cloud Translation Service client object.
"""
if self._client is None:
self._client = TranslationServiceClient(
credentials=self.get_credentials(), client_info=CLIENT_INFO
)
return self._client

@staticmethod
def wait_for_operation(operation: Operation, timeout: int | None = None):
"""Wait for long-lasting operation to complete."""
try:
return operation.result(timeout=timeout)
except GoogleAPICallError:
error = operation.exception(timeout=timeout)
raise AirflowException(error)

def translate_text(
self,
*,
project_id: str = PROVIDE_PROJECT_ID,
contents: Sequence[str],
target_language_code: str,
source_language_code: str | None = None,
mime_type: str | None = None,
location: str | None = None,
model: str | None = None,
transliteration_config: TransliterationConfig | None = None,
glossary_config: TranslateTextGlossaryConfig | None = None,
labels: str | None = None,
timeout: float | _MethodDefault = DEFAULT,
metadata: Sequence[tuple[str, str]] = (),
retry: Retry | _MethodDefault | None = DEFAULT,
) -> dict:
"""
Translate text content provided
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param contents: Required. The content of the input in string
format. Max length 1024 items with 30_000 codepoints recommended.
:param mime_type: Optional. The format of the source text, If left
blank, the MIME type defaults to "text/html".
:param source_language_code: Optional. The ISO-639 language code of the
input text if known. If the source language
isn't specified, the API attempts to identify
the source language automatically and returns
the source language within the response.
:param target_language_code: Required. The ISO-639 language code to use
for translation of the input text
:param location: Optional. Project or location to make a call. Must refer to
a caller's project.
If not specified, 'global used'.
Non-global location is required for requests using AutoML
models or custom glossaries.
Models and glossaries must be within the same region (have
same location-id).
:param model: Optional. The ``model`` type requested for this translation.
If not provided, the default Google model (NMT) will be used.
The format depends on model type:
- AutoML Translation models:
``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}``
- General (built-in) models:
``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt``,
- Translation LLM models:
``projects/{project-number-or-id}/locations/{location-id}/models/general/translation-llm``,
For global (non-regionalized) requests, use ``location-id``
``global``. For example, ``projects/{project-number-or-id}/locations/global/models/general/nmt``.
:param glossary_config: Optional. Glossary to be applied. The glossary must be
within the same region (have the same location-id) as the
model.
:param transliteration_config: Optional. Transliteration to be applied.
:param labels: Optional. The labels with user-defined
metadata for the request.
See https://cloud.google.com/translate/docs/advanced/labels for more information.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
:return: Translate text result from the API response.
"""
client = self.get_client()
location_id = 'global' if not location else location
parent = f"projects/{project_id or self.project_id}/locations/{location_id}"

result = client.translate_text(
request={
'parent': parent,
'source_language_code': source_language_code,
'target_language_code': target_language_code,
'contents': contents,
'mime_type': mime_type,
'glossary_config': glossary_config,
'transliteration_config': transliteration_config,
'model': model,
'labels': labels,
},
timeout=timeout,
retry=retry,
metadata=metadata,
)
return cast(dict, type(result).to_dict(result))

def batch_translate_text(
self,
*,
project_id: str = PROVIDE_PROJECT_ID,
location: str,
source_language_code: str,
target_language_codes: MutableSequence[str],
input_configs: MutableSequence[InputConfig | dict],
output_config: OutputConfig | dict,
models: str | None = None,
glossaries: MutableMapping[str, TranslateTextGlossaryConfig] | None = None,
labels: MutableMapping[str, str] | None = None,
timeout: float | _MethodDefault = DEFAULT,
metadata: Sequence[tuple[str, str]] = (),
retry: Retry | _MethodDefault | None = DEFAULT,

) -> Operation:
"""
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Optional. Project or location to make a call. Must refer to
a caller's project. Must be non-global.
:param source_language_code: Required. Source language code.
:param target_language_codes: Required. Specify up to 10 language codes
here.
:param models: Optional. The models to use for translation. Map's key is
target language code. Map's value is model name. Value can
be a built-in general model, or an AutoML Translation model.
The value format depends on model type:
- AutoML Translation models:
``projects/{project-number-or-id}/locations/{location-id}/models/{model-id}``
- General (built-in) models:
``projects/{project-number-or-id}/locations/{location-id}/models/general/nmt``,
If the map is empty or a specific model is not requested for
a language pair, then default google model (nmt) is used.
:param input_configs: Required. Input configurations.
The total number of files matched should be <=
100. The total content size should be <= 100M
Unicode codepoints. The files must use UTF-8
encoding.
:param output_config: Required. Output configuration.
:param glossaries: Optional. Glossaries to be applied for
translation. It's keyed by target language code.
:param labels: Optional. The labels with user-defined metadata for the request.
See https://cloud.google.com/translate/docs/advanced/labels for more information.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
:returns: Operation object with the batch text translate results,
that are returned by batches as they are ready.
"""
client = self.get_client()
if location == 'global':
raise AirflowException(
'Global location is not allowed for the batch text translation, '
'please provide the correct value!'
)
parent = f"projects/{project_id or self.project_id}/locations/{location}"
result = client.batch_translate_text(
request={
'parent': parent,
'source_language_code': source_language_code,
'target_language_codes': target_language_codes,
'input_configs': input_configs,
'output_config': output_config,
'glossaries': glossaries,
'models': models,
'labels': labels,
},
timeout=timeout,
retry=retry,
metadata=metadata,
)
return result
33 changes: 33 additions & 0 deletions airflow/providers/google/cloud/links/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
+ "/locations/{location}/datasets/{dataset_id}/predict;modelId={model_id}?project={project_id}"
)

TRANSLATION_TRANSLATE_TEXT_BATCH = BASE_LINK + "/storage/browser/{output_uri_prefix}?project={project_id}"


class TranslationLegacyDatasetLink(BaseGoogleLink):
"""
Expand Down Expand Up @@ -178,3 +180,34 @@ def persist(
"project_id": project_id,
},
)


class TranslateTextBatchLink(BaseGoogleLink):
"""
Helper class for constructing Translation results for the text batch translate.
"""

name = "Text Translate Batch"
key = "translate_text_batch"
format_str = TRANSLATION_TRANSLATE_TEXT_BATCH

@staticmethod
def extract_output_uri_prefix(output_config):
return output_config["gcs_destination"]["output_uri_prefix"].rpartition("gs://")[-1]

@staticmethod
def persist(
context: Context,
task_instance,
project_id: str,
output_config: dict,
):
task_instance.xcom_push(
context,
key=TranslateTextBatchLink.key,
value={
"project_id": project_id,
"output_uri_prefix": TranslateTextBatchLink.extract_output_uri_prefix(output_config),
},
)
Loading

0 comments on commit a985a0d

Please sign in to comment.