Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate VertexAI PaLM text generative model #136

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -573,16 +573,6 @@ To get a pipeline job list you can use
Interacting with Generative AI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To generate a prediction via language model you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.

.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
:end-before: [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]

To generate text embeddings you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
class GenerativeModelHook(GoogleBaseHook):
"""Hook for Google Cloud Vertex AI Generative Model APIs."""

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.get_generative_model",
category=AirflowProviderDeprecationWarning,
)
def get_text_generation_model(self, pretrained_model: str):
"""Return a Model Garden Model object based on Text Generation."""
model = TextGenerationModel.from_pretrained(pretrained_model)
Expand Down Expand Up @@ -275,6 +280,11 @@ def prompt_multimodal_model_with_media(

return response.text

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.generative_model_generate_content",
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def text_generation_model_predict(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ def execute(self, context: Context):
return response


@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelGenerateContentOperator",
category=AirflowProviderDeprecationWarning,
)
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
"""
Uses the Vertex AI PaLM API to generate natural language text.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,18 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
def test_text_generation_model_predict(self, mock_model) -> None:
self.hook.text_generation_model_predict(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
mock_model.assert_called_once_with(TEST_LANGUAGE_PRETRAINED_MODEL)
mock_model.return_value.predict.assert_called_once_with(
prompt=TEST_PROMPT,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
self.hook.text_generation_model_predict(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
assert_warning("generative_model_generate_content", warnings)

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model"))
def test_text_embedding_model_get_embeddings(self, mock_model) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,28 +278,46 @@ def test_execute(self, mock_hook):


class TestVertexAITextGenerationModelPredictOperator:
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "text-bison"
temperature = 0.0
max_output_tokens = 256
top_p = 0.8
top_k = 40

def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
assert_warning("GenerativeModelGenerateContentOperator", warnings)

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "text-bison"
temperature = 0.0
max_output_tokens = 256
top_p = 0.8
top_k = 40

op = TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
pretrained_model=pretrained_model,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(AirflowProviderDeprecationWarning):
op = TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -308,12 +326,12 @@ def test_execute(self, mock_hook):
mock_hook.return_value.text_generation_model_predict.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
pretrained_model=pretrained_model,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
TextEmbeddingModelGetEmbeddingsOperator,
TextGenerationModelPredictOperator,
)

PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "vertex_ai_generative_model_dag"
REGION = "us-central1"
PROMPT = "In 10 words or less, why is Apache Airflow amazing?"
CONTENTS = [PROMPT]
LANGUAGE_MODEL = "text-bison"
TEXT_EMBEDDING_MODEL = "textembedding-gecko"
MULTIMODAL_MODEL = "gemini-pro"
MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
Expand Down Expand Up @@ -117,16 +115,6 @@
catchup=False,
tags=["example", "vertex_ai", "generative_model"],
) as dag:
# [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
predict_task = TextGenerationModelPredictOperator(
task_id="predict_task",
project_id=PROJECT_ID,
location=REGION,
prompt=PROMPT,
pretrained_model=LANGUAGE_MODEL,
)
# [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]

# [START how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
generate_embeddings_task = TextEmbeddingModelGetEmbeddingsOperator(
task_id="generate_embeddings_task",
Expand Down
13 changes: 7 additions & 6 deletions tests/always/test_project_structure.py
moiseenkov marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
"airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator",
"airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator",
Expand All @@ -367,6 +368,12 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360DownloadReportV2Operator",
Expand All @@ -385,7 +392,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}

MISSING_EXAMPLES_FOR_CLASSES = {
"airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator",
"airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator",
"airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator",
Expand All @@ -394,11 +400,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.vertex_ai.auto_ml.AutoMLTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.UpdateEndpointOperator",
"airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
}

ASSETS_NOT_REQUIRED = {
Expand Down