From 46cbd1c064d4ef5df30ca46c1399841bfefb44ac Mon Sep 17 00:00:00 2001 From: Oleg Kachur Date: Tue, 19 Nov 2024 19:02:11 +0000 Subject: [PATCH] Introduce the translation API v3 (advanced) models operators. - TranslateCreateModelOperator - TranslateModelsListOperator - TranslateDeleteModelOperator More details on using AutoML translation: https://cloud.google.com/translate/docs/advanced/automl-beginner. --- .../operators/cloud/translate.rst | 63 +++++ docs/spelling_wordlist.txt | 1 + .../providers/google/cloud/hooks/translate.py | 154 +++++++++++ .../providers/google/cloud/links/translate.py | 63 +++++ .../google/cloud/operators/translate.py | 255 ++++++++++++++++++ .../google/cloud/operators/test_translate.py | 156 +++++++++++ .../translate/example_translate_model.py | 178 ++++++++++++ 7 files changed, 870 insertions(+) create mode 100644 providers/tests/system/google/cloud/translate/example_translate_model.py diff --git a/docs/apache-airflow-providers-google/operators/cloud/translate.rst b/docs/apache-airflow-providers-google/operators/cloud/translate.rst index d56fac26dbeb4..5bda3d9085a6c 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/translate.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/translate.rst @@ -184,6 +184,69 @@ Basic usage of the operator: :end-before: [END howto_operator_translate_automl_delete_dataset] +.. _howto/operator:TranslateCreateModelOperator: + +TranslateCreateModelOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Create a native translation model using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_model.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_create_model] + :end-before: [END howto_operator_translate_automl_create_model] + + +.. _howto/operator:TranslateModelsListOperator: + +TranslateModelsListOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Get list of native translation models using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateModelsListOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_model.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_list_models] + :end-before: [END howto_operator_translate_automl_list_models] + + +.. _howto/operator:TranslateDeleteModelOperator: + +TranslateDeleteModelOperator +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Delete a native translation model using Cloud Translate API (Advanced V3). + +For parameter definition, take a look at +:class:`~airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator` + +Using the operator +"""""""""""""""""" + +Basic usage of the operator: + +.. exampleinclude:: /../../providers/tests/system/google/cloud/translate/example_translate_model.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_translate_automl_delete_model] + :end-before: [END howto_operator_translate_automl_delete_model] + + More information """""""""""""""""" See: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index fa8ffb4a2c3c6..dad1f8705f46e 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -970,6 +970,7 @@ linux ListDatasetsPager ListGenerator ListInfoTypesResponse +ListModelsPager ListSecretsPager Liveness liveness diff --git a/providers/src/airflow/providers/google/cloud/hooks/translate.py b/providers/src/airflow/providers/google/cloud/hooks/translate.py index 6ddb220f3e789..09f25333e151d 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/translate.py +++ b/providers/src/airflow/providers/google/cloud/hooks/translate.py @@ -562,3 +562,157 @@ def delete_dataset( metadata=metadata, ) return result + + def create_model( + self, + dataset_id: str, + display_name: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Create the native model by training on translation dataset provided. + + :param dataset_id: ID of dataset to be used for model training. + :param display_name: Display name of the model trained. + A-Z and a-z, underscores (_), and ASCII digits 0-9. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object with the model creation results, when finished. + """ + client = self.get_client() + project_id = project_id or self.project_id + parent = f"projects/{project_id}/locations/{location}" + dataset = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" + result = client.create_model( + request={ + "parent": parent, + "model": { + "display_name": display_name, + "dataset": dataset, + }, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def get_model( + self, + model_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> automl_translation.Model: + """ + Retrieve the dataset for the given model_id. + + :param model_id: ID of translation model to be retrieved. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `automl_translation.Model` instance. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + return client.get_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_models( + self, + project_id: str, + location: str, + filter_str: str | None = None, + page_size: int | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + ) -> pagers.ListModelsPager: + """ + List translation models in a project. + + :param project_id: ID of the Google Cloud project where models are located. If not provided + default project_id is used. + :param location: The location of the project. + :param filter_str: An optional expression for filtering the models that will + be returned. Supported filter: ``dataset_id=${dataset_id}``. + :param page_size: Optional custom page size value. The server can + return fewer results than requested. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: ``pagers.ListDatasetsPager`` instance, iterable object to retrieve the datasets list. + """ + client = self.get_client() + parent = f"projects/{project_id}/locations/{location}" + result = client.list_models( + request={ + "parent": parent, + "filter": filter_str, + "page_size": page_size, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + def delete_model( + self, + model_id: str, + project_id: str, + location: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete the translation model and all of its contents. + + :param model_id: ID of model to be deleted. + :param project_id: ID of the Google Cloud project where dataset is located. If not provided + default project_id is used. + :param location: The location of the project. + :param retry: A retry object used to retry requests. If `None` is specified, requests will not be + retried. + :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if + `retry` is specified, the timeout applies to each individual attempt. + :param metadata: Additional metadata that is provided to the method. + + :return: `Operation` object with dataset deletion results, when finished. + """ + client = self.get_client() + name = f"projects/{project_id}/locations/{location}/models/{model_id}" + result = client.delete_model( + request={"name": name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result diff --git a/providers/src/airflow/providers/google/cloud/links/translate.py b/providers/src/airflow/providers/google/cloud/links/translate.py index 0d1489ddcfc81..55db26508388d 100644 --- a/providers/src/airflow/providers/google/cloud/links/translate.py +++ b/providers/src/airflow/providers/google/cloud/links/translate.py @@ -50,6 +50,12 @@ ) TRANSLATION_NATIVE_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}" +TRANSLATION_NATIVE_MODEL_LINK = ( + TRANSLATION_BASE_LINK + + "/locations/{location}/datasets/{dataset_id}/evaluate;modelId={model_id}?project={project_id}" +) +TRANSLATION_MODELS_LIST_LINK = TRANSLATION_BASE_LINK + "/models/list?project={project_id}" + class TranslationLegacyDatasetLink(BaseGoogleLink): """ @@ -270,3 +276,60 @@ def persist( "project_id": project_id, }, ) + + +class TranslationModelLink(BaseGoogleLink): + """ + Helper class for constructing Translation Model link. + + Link for legacy and native models. + """ + + name = "Translation Model" + key = "translation_model" + format_str = TRANSLATION_NATIVE_MODEL_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationLegacyModelLink.key, + value={ + "location": task_instance.location, + "dataset_id": dataset_id, + "model_id": model_id, + "project_id": project_id, + }, + ) + + +class TranslationModelsListLink(BaseGoogleLink): + """ + Helper class for constructing Translation Models List link. + + Both legacy and native models are available under this link. + """ + + name = "Translation Models List" + key = "translation_models_list" + format_str = TRANSLATION_MODELS_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationModelsListLink.key, + value={ + "project_id": project_id, + }, + ) diff --git a/providers/src/airflow/providers/google/cloud/operators/translate.py b/providers/src/airflow/providers/google/cloud/operators/translate.py index d384e9b8efa9b..7233e38743791 100644 --- a/providers/src/airflow/providers/google/cloud/operators/translate.py +++ b/providers/src/airflow/providers/google/cloud/operators/translate.py @@ -29,6 +29,8 @@ from airflow.providers.google.cloud.links.translate import ( TranslateTextBatchLink, TranslationDatasetsListLink, + TranslationModelLink, + TranslationModelsListLink, TranslationNativeDatasetLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator @@ -722,3 +724,256 @@ def execute(self, context: Context): ) hook.wait_for_operation_done(operation=operation, timeout=self.timeout) self.log.info("Dataset deletion complete!") + + +class TranslateCreateModelOperator(GoogleCloudBaseOperator): + """ + Creates a Google Cloud Translate model. + + Creates a `native` translation model, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateCreateModelOperator`. + + :param dataset_id: The dataset id used for model training. + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "dataset_id", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationModelLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + dataset_id: str, + display_name: str, + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + metadata: Sequence[tuple[str, str]] = (), + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.dataset_id = dataset_id + self.display_name = display_name + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context) -> str: + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + self.log.info("Model creation started, dataset_id %s...", self.dataset_id) + try: + result_operation = hook.create_model( + dataset_id=self.dataset_id, + display_name=self.display_name, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + except GoogleAPICallError as e: + self.log.error("Error submitting create_model operation ") + raise AirflowException(e) + + self.log.info("Training has started") + hook.wait_for_operation_done(operation=result_operation) + result = hook.wait_for_operation_result(operation=result_operation) + result = type(result).to_dict(result) + model_id = hook.extract_object_id(result) + self.xcom_push(context, key="model_id", value=model_id) + self.log.info("Model creation complete. The model_id: %s.", model_id) + + project_id = self.project_id or hook.project_id + TranslationModelLink.persist( + context=context, + task_instance=self, + dataset_id=self.dataset_id, + model_id=model_id, + project_id=project_id, + ) + return result + + +class TranslateModelsListOperator(GoogleCloudBaseOperator): + """ + Get a list of native Google Cloud Translation models in a project. + + Get project's list of `native` translation models, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateModelsListOperator`. + + :param project_id: ID of the Google Cloud project where dataset is located. + If not provided default project_id is used. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + operator_extra_links = (TranslationModelsListLink(),) + + def __init__( + self, + *, + project_id: str = PROVIDE_PROJECT_ID, + location: str, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | _MethodDefault = DEFAULT, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + project_id = self.project_id or hook.project_id + TranslationModelsListLink.persist( + context=context, + task_instance=self, + project_id=project_id, + ) + self.log.info("Requesting models list") + results_pager = hook.list_models( + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + result_ids = [] + for model_item in results_pager: + model_data = type(model_item).to_dict(model_item) + model_id = hook.extract_object_id(model_data) + result_ids.append(model_id) + self.log.info("Fetching the models list complete. Model id-s: %s", result_ids) + return result_ids + + +class TranslateDeleteModelOperator(GoogleCloudBaseOperator): + """ + Delete translation model and all of its contents. + + Deletes the translation model and it's data, using API V3. + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:TranslateDeleteModelOperator`. + + :param model_id: The model_id of target native model to be deleted. + :param location: The location of the project. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "model_id", + "location", + "project_id", + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + *, + model_id: str, + location: str, + project_id: str = PROVIDE_PROJECT_ID, + metadata: Sequence[tuple[str, str]] = (), + timeout: float | None = None, + retry: Retry | _MethodDefault = DEFAULT, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.model_id = model_id + self.project_id = project_id + self.location = location + self.metadata = metadata + self.timeout = timeout + self.retry = retry + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + def execute(self, context: Context): + hook = TranslateHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) + self.log.info("Deleting the model %s...", self.model_id) + operation = hook.delete_model( + model_id=self.model_id, + location=self.location, + project_id=self.project_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + hook.wait_for_operation_done(operation=operation, timeout=self.timeout) + self.log.info("Model deletion complete!") diff --git a/providers/tests/google/cloud/operators/test_translate.py b/providers/tests/google/cloud/operators/test_translate.py index 45af2dae92890..a732b0a20474f 100644 --- a/providers/tests/google/cloud/operators/test_translate.py +++ b/providers/tests/google/cloud/operators/test_translate.py @@ -26,9 +26,12 @@ from airflow.providers.google.cloud.operators.translate import ( CloudTranslateTextOperator, TranslateCreateDatasetOperator, + TranslateCreateModelOperator, TranslateDatasetsListOperator, TranslateDeleteDatasetOperator, + TranslateDeleteModelOperator, TranslateImportDataOperator, + TranslateModelsListOperator, TranslateTextBatchOperator, TranslateTextOperator, ) @@ -39,6 +42,7 @@ IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"] PROJECT_ID = "test-project-id" DATASET_ID = "sample_ds_id" +MODEL_ID = "sample_model_id" TIMEOUT_VALUE = 30 @@ -386,3 +390,155 @@ def test_minimal_green_path(self, mock_hook): metadata=(), ) wait_for_done.assert_called_once_with(operation=m_delete_method_result, timeout=TIMEOUT_VALUE) + + +class TestTranslateModelCreate: + @mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator.xcom_push") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_xcom_push, mock_link_persist): + MODEL_DISPLAY_NAME = "model_display_name_01" + MODEL_CREATION_RESULT_SAMPLE = { + "display_name": MODEL_DISPLAY_NAME, + "name": f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID}", + "dataset": f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/{DATASET_ID}", + "source_language_code": "", + "target_language_code": "", + "create_time": "2024-11-15T14:05:00Z", + "update_time": "2024-11-16T01:09:03Z", + "test_example_count": 1000, + "train_example_count": 115, + "validate_example_count": 140, + } + sample_operation = mock.MagicMock() + sample_operation.result.return_value = automl_translation.Model(MODEL_CREATION_RESULT_SAMPLE) + + mock_hook.return_value.create_model.return_value = sample_operation + mock_hook.return_value.wait_for_operation_result.side_effect = lambda operation: operation.result() + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + op = TranslateCreateModelOperator( + task_id="task_id", + display_name=MODEL_DISPLAY_NAME, + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=None, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.create_model.assert_called_once_with( + display_name=MODEL_DISPLAY_NAME, + dataset_id=DATASET_ID, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=None, + metadata=(), + ) + mock_xcom_push.assert_called_once_with(context, key="model_id", value=MODEL_ID) + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + model_id=MODEL_ID, + project_id=PROJECT_ID, + dataset_id=DATASET_ID, + ) + assert result == MODEL_CREATION_RESULT_SAMPLE + + +class TestTranslateListModels: + @mock.patch("airflow.providers.google.cloud.links.translate.TranslationModelsListLink.persist") + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook, mock_link_persist): + MODEL_ID_1 = "sample_model_1" + MODEL_ID_2 = "sample_model_2" + model_result_1 = automl_translation.Model( + dict( + display_name="model_1_display_name", + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_1}", + dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_1", + source_language_code="en", + target_language_code="es", + ) + ) + model_result_2 = automl_translation.Model( + dict( + display_name="model_2_display_name", + name=f"projects/{PROJECT_ID}/locations/{LOCATION}/models/{MODEL_ID_2}", + dataset=f"projects/{PROJECT_ID}/locations/{LOCATION}/datasets/ds_for_model_2", + source_language_code="uk", + target_language_code="en", + ) + ) + mock_hook.return_value.list_models.return_value = [model_result_1, model_result_2] + mock_hook.return_value.extract_object_id = TranslateHook.extract_object_id + + op = TranslateModelsListOperator( + task_id="task_id", + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + result = op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.list_models.assert_called_once_with( + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + assert result == [MODEL_ID_1, MODEL_ID_2] + mock_link_persist.assert_called_once_with( + context=context, + task_instance=op, + project_id=PROJECT_ID, + ) + + +class TestTranslateDeleteModel: + @mock.patch("airflow.providers.google.cloud.operators.translate.TranslateHook") + def test_minimal_green_path(self, mock_hook): + m_delete_method_result = mock.MagicMock() + mock_hook.return_value.delete_model.return_value = m_delete_method_result + wait_for_done = mock_hook.return_value.wait_for_operation_done + + op = TranslateDeleteModelOperator( + task_id="task_id", + model_id=MODEL_ID, + project_id=PROJECT_ID, + location=LOCATION, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + ) + context = mock.MagicMock() + op.execute(context=context) + mock_hook.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.return_value.delete_model.assert_called_once_with( + model_id=MODEL_ID, + project_id=PROJECT_ID, + location=LOCATION, + timeout=TIMEOUT_VALUE, + retry=DEFAULT, + metadata=(), + ) + wait_for_done.assert_called_once_with(operation=m_delete_method_result, timeout=TIMEOUT_VALUE) diff --git a/providers/tests/system/google/cloud/translate/example_translate_model.py b/providers/tests/system/google/cloud/translate/example_translate_model.py new file mode 100644 index 0000000000000..3514668fa2df8 --- /dev/null +++ b/providers/tests/system/google/cloud/translate/example_translate_model.py @@ -0,0 +1,178 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG that translates text in Google Cloud Translate using V3 API version +service in the Google Cloud. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator +from airflow.providers.google.cloud.operators.translate import ( + TranslateCreateDatasetOperator, + TranslateCreateModelOperator, + TranslateDeleteDatasetOperator, + TranslateDeleteModelOperator, + TranslateImportDataOperator, + TranslateModelsListOperator, + TranslateTextOperator, +) +from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator +from airflow.utils.trigger_rule import TriggerRule + +DAG_ID = "gcp_translate_automl_native_model" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +REGION = "us-central1" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" +DATA_SAMPLE_GCS_BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}".replace("_", "-") +DATA_FILE_NAME = "import_en-es_short.tsv" +RESOURCE_PATH = f"V3_translate/create_ds/import_data/{DATA_FILE_NAME}" +COPY_DATA_PATH = f"gs://{RESOURCE_DATA_BUCKET}/V3_translate/create_ds/import_data/{DATA_FILE_NAME}" +DST_PATH = f"translate/import/{DATA_FILE_NAME}" +DATASET_DATA_PATH = f"gs://{DATA_SAMPLE_GCS_BUCKET_NAME}/{DST_PATH}" +DATASET = { + "display_name": f"op_ds_native{DAG_ID}_{ENV_ID}", + "source_language_code": "es", + "target_language_code": "en", +} + + +with DAG( + DAG_ID, + schedule="@once", # Override to match your needs + start_date=datetime(2024, 11, 1), + catchup=False, + tags=[ + "example", + "translate_model", + ], +) as dag: + create_bucket = GCSCreateBucketOperator( + task_id="create_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + storage_class="REGIONAL", + location=REGION, + ) + copy_dataset_source_tsv = GCSToGCSOperator( + task_id="copy_dataset_file", + source_bucket=RESOURCE_DATA_BUCKET, + source_object=RESOURCE_PATH, + destination_bucket=DATA_SAMPLE_GCS_BUCKET_NAME, + destination_object=DST_PATH, + ) + + create_dataset_op = TranslateCreateDatasetOperator( + task_id="translate_v3_ds_create", + dataset=DATASET, + project_id=PROJECT_ID, + location=REGION, + ) + + import_ds_data_op = TranslateImportDataOperator( + task_id="translate_v3_ds_import_data", + dataset_id=create_dataset_op.output["dataset_id"], + input_config={ + "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": DATASET_DATA_PATH}}] + }, + project_id=PROJECT_ID, + location=REGION, + ) + + # [START howto_operator_translate_automl_create_model] + create_model = TranslateCreateModelOperator( + task_id="translate_v3_model_create", + display_name=f"native_model_{ENV_ID}"[:32].replace("-", "_"), + dataset_id=create_dataset_op.output["dataset_id"], + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_create_model] + + # [START howto_operator_translate_automl_list_models] + list_models = TranslateModelsListOperator( + task_id="translate_v3_list_models", + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_list_models] + + model_id = create_model.output["model_id"] + + translate_text_with_model = TranslateTextOperator( + task_id="translate_v3_op", + contents=["Hola!", "Puedes traerme una taza de café, por favor?"], + # AutoML model format + model=f"projects/{PROJECT_ID}/locations/{REGION}/models/{model_id}", + source_language_code="es", + target_language_code="en", + ) + + # [START howto_operator_translate_automl_delete_model] + delete_model = TranslateDeleteModelOperator( + task_id="translate_v3_automl_delete_model", + model_id=model_id, + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_delete_model] + + delete_ds_op = TranslateDeleteDatasetOperator( + task_id="translate_v3_ds_delete", + dataset_id=create_dataset_op.output["dataset_id"], + project_id=PROJECT_ID, + location=REGION, + ) + # [END howto_operator_translate_automl_delete_dataset] + + delete_bucket = GCSDeleteBucketOperator( + task_id="delete_bucket", + bucket_name=DATA_SAMPLE_GCS_BUCKET_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + # TEST SETUP + [create_bucket >> copy_dataset_source_tsv] + >> create_dataset_op + >> import_ds_data_op + # TEST BODY + >> create_model + >> list_models + >> translate_text_with_model + >> delete_model + # TEST TEARDOWN + >> delete_ds_op + >> delete_bucket + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)