diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index f8f87040f9fc3..173b23dfa3002 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -573,16 +573,6 @@ To get a pipeline job list you can use Interacting with Generative AI ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -To generate a prediction via language model you can use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator`. -The operator returns the model's response in :ref:`XCom ` under ``model_response`` key. - -.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_text_generation_model_predict_operator] - :end-before: [END how_to_cloud_vertex_ai_text_generation_model_predict_operator] - To generate text embeddings you can use :class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator`. The operator returns the model's response in :ref:`XCom ` under ``model_response`` key. diff --git a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py index 931cc19273781..7e506641484b3 100644 --- a/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py +++ b/providers/src/airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py @@ -43,6 +43,11 @@ class GenerativeModelHook(GoogleBaseHook): """Hook for Google Cloud Vertex AI Generative Model APIs.""" + @deprecated( + planned_removal_date="April 09, 2025", + use_instead="GenerativeModelHook.get_generative_model", + category=AirflowProviderDeprecationWarning, + ) def get_text_generation_model(self, pretrained_model: str): """Return a Model Garden Model object based on Text Generation.""" model = TextGenerationModel.from_pretrained(pretrained_model) @@ -275,6 +280,11 @@ def prompt_multimodal_model_with_media( return response.text + @deprecated( + planned_removal_date="April 09, 2025", + use_instead="GenerativeModelHook.generative_model_generate_content", + category=AirflowProviderDeprecationWarning, + ) @GoogleBaseHook.fallback_to_default_project_id def text_generation_model_predict( self, diff --git a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py index 78eba3338961a..42e4fdc588e43 100644 --- a/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py +++ b/providers/src/airflow/providers/google/cloud/operators/vertex_ai/generative_model.py @@ -353,6 +353,11 @@ def execute(self, context: Context): return response +@deprecated( + planned_removal_date="April 09, 2025", + use_instead="GenerativeModelGenerateContentOperator", + category=AirflowProviderDeprecationWarning, +) class TextGenerationModelPredictOperator(GoogleCloudBaseOperator): """ Uses the Vertex AI PaLM API to generate natural language text. diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py index 35d3fc9256e6c..21741a617ea92 100644 --- a/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py +++ b/providers/tests/google/cloud/hooks/vertex_ai/test_generative_model.py @@ -205,24 +205,18 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model")) def test_text_generation_model_predict(self, mock_model) -> None: - self.hook.text_generation_model_predict( - project_id=GCP_PROJECT, - location=GCP_LOCATION, - prompt=TEST_PROMPT, - pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL, - temperature=TEST_TEMPERATURE, - max_output_tokens=TEST_MAX_OUTPUT_TOKENS, - top_p=TEST_TOP_P, - top_k=TEST_TOP_K, - ) - mock_model.assert_called_once_with(TEST_LANGUAGE_PRETRAINED_MODEL) - mock_model.return_value.predict.assert_called_once_with( - prompt=TEST_PROMPT, - temperature=TEST_TEMPERATURE, - max_output_tokens=TEST_MAX_OUTPUT_TOKENS, - top_p=TEST_TOP_P, - top_k=TEST_TOP_K, - ) + with pytest.warns(AirflowProviderDeprecationWarning) as warnings: + self.hook.text_generation_model_predict( + project_id=GCP_PROJECT, + location=GCP_LOCATION, + prompt=TEST_PROMPT, + pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL, + temperature=TEST_TEMPERATURE, + max_output_tokens=TEST_MAX_OUTPUT_TOKENS, + top_p=TEST_TOP_P, + top_k=TEST_TOP_K, + ) + assert_warning("generative_model_generate_content", warnings) @mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model")) def test_text_embedding_model_get_embeddings(self, mock_model) -> None: diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py index 5bdb04cb3edb3..709e5d1f78402 100644 --- a/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py +++ b/providers/tests/google/cloud/operators/vertex_ai/test_generative_model.py @@ -278,28 +278,46 @@ def test_execute(self, mock_hook): class TestVertexAITextGenerationModelPredictOperator: + prompt = "In 10 words or less, what is Apache Airflow?" + pretrained_model = "text-bison" + temperature = 0.0 + max_output_tokens = 256 + top_p = 0.8 + top_k = 40 + + def test_deprecation_warning(self): + with pytest.warns(AirflowProviderDeprecationWarning) as warnings: + TextGenerationModelPredictOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + prompt=self.prompt, + pretrained_model=self.pretrained_model, + temperature=self.temperature, + max_output_tokens=self.max_output_tokens, + top_p=self.top_p, + top_k=self.top_k, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + assert_warning("GenerativeModelGenerateContentOperator", warnings) + @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook")) def test_execute(self, mock_hook): - prompt = "In 10 words or less, what is Apache Airflow?" - pretrained_model = "text-bison" - temperature = 0.0 - max_output_tokens = 256 - top_p = 0.8 - top_k = 40 - - op = TextGenerationModelPredictOperator( - task_id=TASK_ID, - project_id=GCP_PROJECT, - location=GCP_LOCATION, - prompt=prompt, - pretrained_model=pretrained_model, - temperature=temperature, - max_output_tokens=max_output_tokens, - top_p=top_p, - top_k=top_k, - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + with pytest.warns(AirflowProviderDeprecationWarning): + op = TextGenerationModelPredictOperator( + task_id=TASK_ID, + project_id=GCP_PROJECT, + location=GCP_LOCATION, + prompt=self.prompt, + pretrained_model=self.pretrained_model, + temperature=self.temperature, + max_output_tokens=self.max_output_tokens, + top_p=self.top_p, + top_k=self.top_k, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) op.execute(context={"ti": mock.MagicMock()}) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, @@ -308,12 +326,12 @@ def test_execute(self, mock_hook): mock_hook.return_value.text_generation_model_predict.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, - prompt=prompt, - pretrained_model=pretrained_model, - temperature=temperature, - max_output_tokens=max_output_tokens, - top_p=top_p, - top_k=top_k, + prompt=self.prompt, + pretrained_model=self.pretrained_model, + temperature=self.temperature, + max_output_tokens=self.max_output_tokens, + top_p=self.top_p, + top_k=self.top_k, ) diff --git a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py index 4384626999d0a..bafb361bc476a 100644 --- a/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py +++ b/providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py @@ -36,7 +36,6 @@ GenerativeModelGenerateContentOperator, RunEvaluationOperator, TextEmbeddingModelGetEmbeddingsOperator, - TextGenerationModelPredictOperator, ) PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") @@ -44,7 +43,6 @@ REGION = "us-central1" PROMPT = "In 10 words or less, why is Apache Airflow amazing?" CONTENTS = [PROMPT] -LANGUAGE_MODEL = "text-bison" TEXT_EMBEDDING_MODEL = "textembedding-gecko" MULTIMODAL_MODEL = "gemini-pro" MULTIMODAL_VISION_MODEL = "gemini-pro-vision" @@ -117,16 +115,6 @@ catchup=False, tags=["example", "vertex_ai", "generative_model"], ) as dag: - # [START how_to_cloud_vertex_ai_text_generation_model_predict_operator] - predict_task = TextGenerationModelPredictOperator( - task_id="predict_task", - project_id=PROJECT_ID, - location=REGION, - prompt=PROMPT, - pretrained_model=LANGUAGE_MODEL, - ) - # [END how_to_cloud_vertex_ai_text_generation_model_predict_operator] - # [START how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator] generate_embeddings_task = TextEmbeddingModelGetEmbeddingsOperator( task_id="generate_embeddings_task", diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 140e5d2d15098..de4fca69e51ae 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -353,6 +353,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator", "airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator", "airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator", + "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator", "airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator", "airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator", "airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator", @@ -367,6 +368,12 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator", "airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator", "airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator", + "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator", + "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator", + "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator", + "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator", + "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator", + "airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator", "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator", "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator", "airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360DownloadReportV2Operator", @@ -385,7 +392,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest } MISSING_EXAMPLES_FOR_CLASSES = { - "airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator", "airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator", "airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator", "airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator", @@ -394,11 +400,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.AutoMLTrainingJobBaseOperator", "airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.UpdateEndpointOperator", "airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator", - "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator", - "airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator", - "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator", - "airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator", - "airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator", } ASSETS_NOT_REQUIRED = {