From de529314cce640833452b42c9eafd66e16cf3f82 Mon Sep 17 00:00:00 2001 From: e-halan Date: Tue, 30 Apr 2024 18:04:27 +0000 Subject: [PATCH] Reroute AutoML operator links to Google Translation links --- .../providers/google/cloud/links/automl.py | 38 ++++ .../providers/google/cloud/links/translate.py | 180 ++++++++++++++++++ .../google/cloud/operators/automl.py | 56 +++--- airflow/providers/google/provider.yaml | 6 + .../google/cloud/operators/test_automl.py | 26 ++- 5 files changed, 276 insertions(+), 30 deletions(-) create mode 100644 airflow/providers/google/cloud/links/translate.py diff --git a/airflow/providers/google/cloud/links/automl.py b/airflow/providers/google/cloud/links/automl.py index 79561d5b48132..c66636b8332a9 100644 --- a/airflow/providers/google/cloud/links/automl.py +++ b/airflow/providers/google/cloud/links/automl.py @@ -21,6 +21,9 @@ from typing import TYPE_CHECKING +from deprecated import deprecated + +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.google.cloud.links.base import BaseGoogleLink if TYPE_CHECKING: @@ -44,6 +47,13 @@ ) +@deprecated( + reason=( + "Class `AutoMLDatasetLink` has been deprecated and will be removed after 31.12.2024." + "Please use `TranslationLegacyDatasetLink` from `airflow/providers/google/cloud/links/translate.py` instead." + ), + category=AirflowProviderDeprecationWarning, +) class AutoMLDatasetLink(BaseGoogleLink): """Helper class for constructing AutoML Dataset link.""" @@ -65,6 +75,13 @@ def persist( ) +@deprecated( + reason=( + "Class `AutoMLDatasetListLink` has been deprecated and will be removed after 31.12.2024." + "Please use `TranslationDatasetListLink` from `airflow/providers/google/cloud/links/translate.py` instead." + ), + category=AirflowProviderDeprecationWarning, +) class AutoMLDatasetListLink(BaseGoogleLink): """Helper class for constructing AutoML Dataset List link.""" @@ -87,6 +104,13 @@ def persist( ) +@deprecated( + reason=( + "Class `AutoMLModelLink` has been deprecated and will be removed after 31.12.2024." + "Please use `TranslationLegacyModelLink` from `airflow/providers/google/cloud/links/translate.py` instead." + ), + category=AirflowProviderDeprecationWarning, +) class AutoMLModelLink(BaseGoogleLink): """Helper class for constructing AutoML Model link.""" @@ -114,6 +138,13 @@ def persist( ) +@deprecated( + reason=( + "Class `AutoMLModelTrainLink` has been deprecated and will be removed after 31.12.2024." + "Please use `TranslationLegacyModelTrainLink` from `airflow/providers/google/cloud/links/translate.py` instead." + ), + category=AirflowProviderDeprecationWarning, +) class AutoMLModelTrainLink(BaseGoogleLink): """Helper class for constructing AutoML Model Train link.""" @@ -138,6 +169,13 @@ def persist( ) +@deprecated( + reason=( + "Class `AutoMLModelPredictLink` has been deprecated and will be removed after 31.12.2024." + "Please use `TranslationLegacyModelPredictLink` from `airflow/providers/google/cloud/links/translate.py` instead." + ), + category=AirflowProviderDeprecationWarning, +) class AutoMLModelPredictLink(BaseGoogleLink): """Helper class for constructing AutoML Model Predict link.""" diff --git a/airflow/providers/google/cloud/links/translate.py b/airflow/providers/google/cloud/links/translate.py new file mode 100644 index 0000000000000..2ab38805aef1c --- /dev/null +++ b/airflow/providers/google/cloud/links/translate.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Google Translate links.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BASE_LINK, BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +TRANSLATION_BASE_LINK = BASE_LINK + "/translation" +TRANSLATION_LEGACY_DATASET_LINK = ( + TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/sentences?project={project_id}" +) +TRANSLATION_DATASET_LIST_LINK = TRANSLATION_BASE_LINK + "/datasets?project={project_id}" +TRANSLATION_LEGACY_MODEL_LINK = ( + TRANSLATION_BASE_LINK + + "/locations/{location}/datasets/{dataset_id}/evaluate;modelId={model_id}?project={project_id}" +) +TRANSLATION_LEGACY_MODEL_TRAIN_LINK = ( + TRANSLATION_BASE_LINK + "/locations/{location}/datasets/{dataset_id}/train?project={project_id}" +) +TRANSLATION_LEGACY_MODEL_PREDICT_LINK = ( + TRANSLATION_BASE_LINK + + "/locations/{location}/datasets/{dataset_id}/predict;modelId={model_id}?project={project_id}" +) + + +class TranslationLegacyDatasetLink(BaseGoogleLink): + """ + Helper class for constructing Legacy Translation Dataset link. + + Legacy Datasets are created and managed by AutoML API. + """ + + name = "Translation Legacy Dataset" + key = "translation_legacy_dataset" + format_str = TRANSLATION_LEGACY_DATASET_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationLegacyDatasetLink.key, + value={"location": task_instance.location, "dataset_id": dataset_id, "project_id": project_id}, + ) + + +class TranslationDatasetListLink(BaseGoogleLink): + """Helper class for constructing Translation Dataset List link.""" + + name = "Translation Dataset List" + key = "translation_dataset_list" + format_str = TRANSLATION_DATASET_LIST_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationDatasetListLink.key, + value={ + "project_id": project_id, + }, + ) + + +class TranslationLegacyModelLink(BaseGoogleLink): + """ + Helper class for constructing Translation Legacy Model link. + + Legacy Models are created and managed by AutoML API. + """ + + name = "Translation Legacy Model" + key = "translation_legacy_model" + format_str = TRANSLATION_LEGACY_MODEL_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + dataset_id: str, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationLegacyModelLink.key, + value={ + "location": task_instance.location, + "dataset_id": dataset_id, + "model_id": model_id, + "project_id": project_id, + }, + ) + + +class TranslationLegacyModelTrainLink(BaseGoogleLink): + """ + Helper class for constructing Translation Legacy Model Train link. + + Legacy Models are created and managed by AutoML API. + """ + + name = "Translation Legacy Model Train" + key = "translation_legacy_model_train" + format_str = TRANSLATION_LEGACY_MODEL_TRAIN_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationLegacyModelTrainLink.key, + value={ + "location": task_instance.location, + "dataset_id": task_instance.model["dataset_id"], + "project_id": project_id, + }, + ) + + +class TranslationLegacyModelPredictLink(BaseGoogleLink): + """ + Helper class for constructing Translation Legacy Model Predict link. + + Legacy Models are created and managed by AutoML API. + """ + + name = "Translation Legacy Model Predict" + key = "translation_legacy_model_predict" + format_str = TRANSLATION_LEGACY_MODEL_PREDICT_LINK + + @staticmethod + def persist( + context: Context, + task_instance, + model_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=TranslationLegacyModelPredictLink.key, + value={ + "location": task_instance.location, + "dataset_id": task_instance.model["dataset_id"], + "model_id": model_id, + "project_id": project_id, + }, + ) diff --git a/airflow/providers/google/cloud/operators/automl.py b/airflow/providers/google/cloud/operators/automl.py index d5dbb3b920996..b0db2a1d845e6 100644 --- a/airflow/providers/google/cloud/operators/automl.py +++ b/airflow/providers/google/cloud/operators/automl.py @@ -40,10 +40,13 @@ from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook from airflow.providers.google.cloud.links.automl import ( AutoMLDatasetLink, - AutoMLDatasetListLink, - AutoMLModelLink, - AutoMLModelPredictLink, - AutoMLModelTrainLink, +) +from airflow.providers.google.cloud.links.translate import ( + TranslationDatasetListLink, + TranslationLegacyDatasetLink, + TranslationLegacyModelLink, + TranslationLegacyModelPredictLink, + TranslationLegacyModelTrainLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID @@ -119,8 +122,8 @@ class AutoMLTrainModelOperator(GoogleCloudBaseOperator): "impersonation_chain", ) operator_extra_links = ( - AutoMLModelTrainLink(), - AutoMLModelLink(), + TranslationLegacyModelTrainLink(), + TranslationLegacyModelLink(), ) def __init__( @@ -173,7 +176,9 @@ def execute(self, context: Context): ) project_id = self.project_id or hook.project_id if project_id: - AutoMLModelTrainLink.persist(context=context, task_instance=self, project_id=project_id) + TranslationLegacyModelTrainLink.persist( + context=context, task_instance=self, project_id=project_id + ) operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) result = Model.to_dict(operation_result) model_id = hook.extract_object_id(result) @@ -181,7 +186,7 @@ def execute(self, context: Context): self.xcom_push(context, key="model_id", value=model_id) if project_id: - AutoMLModelLink.persist( + TranslationLegacyModelLink.persist( context=context, task_instance=self, dataset_id=self.model["dataset_id"] or "-", @@ -195,6 +200,9 @@ class AutoMLPredictOperator(GoogleCloudBaseOperator): """ Runs prediction operation on Google Cloud AutoML. + AutoMLPredictOperator for text, image, and video prediction has been deprecated. + Please use endpoint_id param instead of model_id param. + .. seealso:: For more information on how to use this operator, take a look at the guide: :ref:`howto/operator:AutoMLPredictOperator` @@ -228,7 +236,7 @@ class AutoMLPredictOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLModelPredictLink(),) + operator_extra_links = (TranslationLegacyModelPredictLink(),) def __init__( self, @@ -325,7 +333,7 @@ def execute(self, context: Context): project_id = self.project_id or hook.project_id if project_id and self.model_id: - AutoMLModelPredictLink.persist( + TranslationLegacyModelPredictLink.persist( context=context, task_instance=self, model_id=self.model_id, @@ -389,7 +397,7 @@ class AutoMLBatchPredictOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLModelPredictLink(),) + operator_extra_links = (TranslationLegacyModelPredictLink(),) def __init__( self, @@ -462,7 +470,7 @@ def execute(self, context: Context): self.log.info("Batch prediction is ready.") project_id = self.project_id or hook.project_id if project_id: - AutoMLModelPredictLink.persist( + TranslationLegacyModelPredictLink.persist( context=context, task_instance=self, model_id=self.model_id, @@ -511,7 +519,7 @@ class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLDatasetLink(),) + operator_extra_links = (TranslationLegacyDatasetLink(),) def __init__( self, @@ -560,7 +568,7 @@ def execute(self, context: Context): self.xcom_push(context, key="dataset_id", value=dataset_id) project_id = self.project_id or hook.project_id if project_id: - AutoMLDatasetLink.persist( + TranslationLegacyDatasetLink.persist( context=context, task_instance=self, dataset_id=dataset_id, @@ -611,7 +619,7 @@ class AutoMLImportDataOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLDatasetLink(),) + operator_extra_links = (TranslationLegacyDatasetLink(),) def __init__( self, @@ -668,7 +676,7 @@ def execute(self, context: Context): self.log.info("Import is completed") project_id = self.project_id or hook.project_id if project_id: - AutoMLDatasetLink.persist( + TranslationLegacyDatasetLink.persist( context=context, task_instance=self, dataset_id=self.dataset_id, @@ -722,7 +730,7 @@ class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLDatasetLink(),) + operator_extra_links = (TranslationLegacyDatasetLink(),) def __init__( self, @@ -777,7 +785,7 @@ def execute(self, context: Context): self.log.info("Columns specs obtained.") project_id = self.project_id or hook.project_id if project_id: - AutoMLDatasetLink.persist( + TranslationLegacyDatasetLink.persist( context=context, task_instance=self, dataset_id=self.dataset_id, @@ -924,7 +932,7 @@ class AutoMLGetModelOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLModelLink(),) + operator_extra_links = (TranslationLegacyModelLink(),) def __init__( self, @@ -968,7 +976,7 @@ def execute(self, context: Context): model = Model.to_dict(result) project_id = self.project_id or hook.project_id if project_id: - AutoMLModelLink.persist( + TranslationLegacyModelLink.persist( context=context, task_instance=self, dataset_id=model["dataset_id"], @@ -1223,7 +1231,7 @@ class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLDatasetLink(),) + operator_extra_links = (TranslationLegacyDatasetLink(),) def __init__( self, @@ -1273,7 +1281,7 @@ def execute(self, context: Context): self.log.info("Table specs obtained.") project_id = self.project_id or hook.project_id if project_id: - AutoMLDatasetLink.persist( + TranslationLegacyDatasetLink.persist( context=context, task_instance=self, dataset_id=self.dataset_id, @@ -1318,7 +1326,7 @@ class AutoMLListDatasetOperator(GoogleCloudBaseOperator): "project_id", "impersonation_chain", ) - operator_extra_links = (AutoMLDatasetListLink(),) + operator_extra_links = (TranslationDatasetListLink(),) def __init__( self, @@ -1373,7 +1381,7 @@ def execute(self, context: Context): ) project_id = self.project_id or hook.project_id if project_id: - AutoMLDatasetListLink.persist(context=context, task_instance=self, project_id=project_id) + TranslationDatasetListLink.persist(context=context, task_instance=self, project_id=project_id) return result diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 3144b3bbc1483..ccfc030915b93 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -1271,6 +1271,12 @@ extra-links: - airflow.providers.google.common.links.storage.StorageLink - airflow.providers.google.common.links.storage.FileDetailsLink - airflow.providers.google.marketing_platform.links.analytics_admin.GoogleAnalyticsPropertyLink + - airflow.providers.google.cloud.links.translate.TranslationLegacyDatasetLink + - airflow.providers.google.cloud.links.translate.TranslationDatasetListLink + - airflow.providers.google.cloud.links.translate.TranslationLegacyModelLink + - airflow.providers.google.cloud.links.translate.TranslationLegacyModelTrainLink + - airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink + secrets-backends: - airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend diff --git a/tests/providers/google/cloud/operators/test_automl.py b/tests/providers/google/cloud/operators/test_automl.py index cecda2bf23a21..f7bef5452193f 100644 --- a/tests/providers/google/cloud/operators/test_automl.py +++ b/tests/providers/google/cloud/operators/test_automl.py @@ -139,12 +139,13 @@ def test_templating(self, create_task_instance_of_operator): class TestAutoMLBatchPredictOperator: + @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, mock_link_persist): mock_hook.return_value.batch_predict.return_value.result.return_value = BatchPredictResult() mock_hook.return_value.extract_object_id = extract_object_id mock_hook.return_value.wait_for_operation.return_value = BatchPredictResult() - + mock_context = {"ti": mock.MagicMock()} op = AutoMLBatchPredictOperator( model_id=MODEL_ID, location=GCP_LOCATION, @@ -154,7 +155,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, prediction_params={}, ) - op.execute(context=mock.MagicMock()) + op.execute(context=mock_context) mock_hook.return_value.batch_predict.assert_called_once_with( input_config=INPUT_CONFIG, location=GCP_LOCATION, @@ -166,6 +167,12 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) + mock_link_persist.assert_called_once_with( + context=mock_context, + task_instance=op, + model_id=MODEL_ID, + project_id=GCP_PROJECT_ID, + ) @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") def test_execute_deprecated(self, mock_hook): @@ -226,10 +233,11 @@ def test_templating(self, create_task_instance_of_operator): class TestAutoMLPredictOperator: + @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist") @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook") - def test_execute(self, mock_hook): + def test_execute(self, mock_hook, mock_link_persist): mock_hook.return_value.predict.return_value = PredictResponse() - + mock_context = {"ti": mock.MagicMock()} op = AutoMLPredictOperator( model_id=MODEL_ID, location=GCP_LOCATION, @@ -238,7 +246,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, operation_params={"TEST_KEY": "TEST_VALUE"}, ) - op.execute(context=mock.MagicMock()) + op.execute(context=mock_context) mock_hook.return_value.predict.assert_called_once_with( location=GCP_LOCATION, metadata=(), @@ -249,6 +257,12 @@ def test_execute(self, mock_hook): retry=DEFAULT, timeout=None, ) + mock_link_persist.assert_called_once_with( + context=mock_context, + task_instance=op, + model_id=MODEL_ID, + project_id=GCP_PROJECT_ID, + ) @pytest.mark.db_test def test_templating(self, create_task_instance_of_operator):