Skip to content

Commit

Permalink
Add CountTokensOperator for Google Generative AI CountTokensAPI (apac…
Browse files Browse the repository at this point in the history
…he#41908)

* Add CountTokensOperator for Google Generative AI CountTokensAPI

* Update system test DAG with correct arguments
  • Loading branch information
CYarros10 authored Sep 1, 2024
1 parent 2823acd commit 7cf54a7
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 4 deletions.
34 changes: 32 additions & 2 deletions airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

if TYPE_CHECKING:
from google.cloud.aiplatform_v1 import types
from google.cloud.aiplatform_v1 import types as types_v1
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1


class GenerativeModelHook(GoogleBaseHook):
Expand Down Expand Up @@ -367,7 +368,7 @@ def supervised_fine_tuning_train(
adapter_size: int | None = None,
learning_rate_multiplier: float | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> types.TuningJob:
) -> types_v1.TuningJob:
"""
Use the Supervised Fine Tuning API to create a tuning job.
Expand Down Expand Up @@ -406,3 +407,32 @@ def supervised_fine_tuning_train(
sft_tuning_job.refresh()

return sft_tuning_job

@GoogleBaseHook.fallback_to_default_project_id
def count_tokens(
    self,
    contents: list,
    location: str,
    pretrained_model: str = "gemini-pro",
    project_id: str = PROVIDE_PROJECT_ID,
) -> types_v1beta1.CountTokensResponse:
    """
    Count the input tokens of a prompt with the Vertex AI Count Tokens API.

    Intended to be called before the same contents are sent to the Gemini API.

    :param contents: Required. The multi-part content of a message that a user or a program
        gives to the generative model, in order to elicit a specific response.
    :param location: Required. The ID of the Google Cloud location that the service belongs to.
    :param pretrained_model: Name of the pre-trained model handed to ``get_generative_model``.
        Defaults to `gemini-pro`, supporting prompts with text-only input, including
        natural language tasks, multi-turn text and code chat, and code generation.
    :param project_id: The ID of the Google Cloud project that the service belongs to.
        Defaults to ``PROVIDE_PROJECT_ID``.
    :return: The value returned by the model's ``count_tokens`` call.
    """
    # Vertex AI SDK state is initialised per call with this hook's credentials.
    credentials = self.get_credentials()
    vertexai.init(project=project_id, location=location, credentials=credentials)

    return self.get_generative_model(pretrained_model).count_tokens(contents=contents)
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

from typing import TYPE_CHECKING, Sequence

from google.cloud.aiplatform_v1 import types
from google.cloud.aiplatform_v1 import types as types_v1
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
Expand Down Expand Up @@ -665,4 +666,73 @@ def execute(self, context: Context):
self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)

return types.TuningJob.to_dict(response)
return types_v1.TuningJob.to_dict(response)


class CountTokensOperator(GoogleCloudBaseOperator):
    """
    Use the Vertex AI Count Tokens API to calculate the number of input tokens before sending a request to the Gemini API.

    The totals are logged, pushed to XCom under the ``total_tokens`` and
    ``total_billable_characters`` keys, and the full response is returned as a dict.

    :param project_id: Required. The ID of the Google Cloud project that the
        service belongs to (templated).
    :param contents: Required. The multi-part content of a message that a user or a program
        gives to the generative model, in order to elicit a specific response (templated).
    :param location: Required. The ID of the Google Cloud location that the
        service belongs to (templated).
    :param pretrained_model: By default uses the pre-trained model `gemini-pro`,
        supporting prompts with text-only input, including natural language
        tasks, multi-turn text and code chat, and code generation. It can
        output text and code (templated).
    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
    :param impersonation_chain: Optional service account to impersonate using short-term
        credentials, or chained list of accounts required to get the access_token
        of the last account in the list, which will be impersonated in the request.
        If set as a string, the account must grant the originating account
        the Service Account Token Creator IAM role.
        If set as a sequence, the identities from the list must grant
        Service Account Token Creator IAM role to the directly preceding identity, with first
        account from the list granting this role to the originating account (templated).
    """

    template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")

    def __init__(
        self,
        *,
        project_id: str,
        contents: list,
        location: str,
        pretrained_model: str = "gemini-pro",
        gcp_conn_id: str = "google_cloud_default",
        impersonation_chain: str | Sequence[str] | None = None,
        **kwargs,
    ) -> None:
        """Store the request and connection settings; remaining kwargs go to the base operator."""
        super().__init__(**kwargs)
        self.project_id = project_id
        self.location = location
        self.contents = contents
        self.pretrained_model = pretrained_model
        self.gcp_conn_id = gcp_conn_id
        self.impersonation_chain = impersonation_chain

    def execute(self, context: Context):
        """
        Count the tokens of ``contents`` and publish the totals.

        :param context: Airflow task context, used for the XCom pushes.
        :return: The count-tokens response converted with ``CountTokensResponse.to_dict``.
        """
        # The hook is created at execute time and kept on the instance.
        self.hook = GenerativeModelHook(
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
        )
        response = self.hook.count_tokens(
            project_id=self.project_id,
            location=self.location,
            contents=self.contents,
            pretrained_model=self.pretrained_model,
        )

        self.log.info("Total tokens: %s", response.total_tokens)
        self.log.info("Total billable characters: %s", response.total_billable_characters)

        # Push each total under its own key so downstream tasks can pull it directly.
        self.xcom_push(context, key="total_tokens", value=response.total_tokens)
        self.xcom_push(context, key="total_billable_characters", value=response.total_billable_characters)

        return types_v1beta1.CountTokensResponse.to_dict(response)
11 changes: 11 additions & 0 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,17 @@ The operator returns the tuned model's endpoint name in :ref:`XCom <concepts:xco
:start-after: [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
:end-before: [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]


To calculate the number of input tokens before sending a request to the Gemini API, you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.CountTokensOperator`.
The operator returns the total tokens in :ref:`XCom <concepts:xcom>` under the ``total_tokens`` key and the total billable characters under the ``total_billable_characters`` key.

.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_count_tokens_operator]
:end-before: [END how_to_cloud_vertex_ai_count_tokens_operator]

Reference
^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,16 @@ def test_supervised_fine_tuning_train(self, mock_sft_train) -> None:
learning_rate_multiplier=None,
tuned_model_display_name=None,
)

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_generative_model"))
def test_count_tokens(self, mock_model) -> None:
    """``count_tokens`` resolves the requested model and forwards ``contents`` to it."""
    self.hook.count_tokens(
        contents=TEST_CONTENTS,
        location=GCP_LOCATION,
        pretrained_model=TEST_MULTIMODAL_PRETRAINED_MODEL,
        project_id=GCP_PROJECT,
    )

    # The model factory is called exactly once, with the model name only.
    mock_model.assert_called_once_with(TEST_MULTIMODAL_PRETRAINED_MODEL)
    # The resolved model receives the contents unchanged.
    mock_model.return_value.count_tokens.assert_called_once_with(contents=TEST_CONTENTS)
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

# For no Pydantic environment, we need to skip the tests
pytest.importorskip("google.cloud.aiplatform_v1")
pytest.importorskip("google.cloud.aiplatform_v1beta1")
vertexai = pytest.importorskip("vertexai.generative_models")
from vertexai.generative_models import HarmBlockThreshold, HarmCategory, Tool, grounding

from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
CountTokensOperator,
GenerateTextEmbeddingsOperator,
GenerativeModelGenerateContentOperator,
PromptLanguageModelOperator,
Expand Down Expand Up @@ -417,3 +419,32 @@ def test_execute(
tuned_model_display_name=None,
validation_dataset=None,
)


class TestVertexAICountTokensOperator:
    """Unit tests for ``CountTokensOperator``."""

    @mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
    @mock.patch("google.cloud.aiplatform_v1beta1.types.CountTokensResponse.to_dict")
    def test_execute(self, to_dict_mock, mock_hook):
        """``execute`` builds the hook from the connection settings and delegates to ``count_tokens``."""
        # Arguments the operator must forward to the hook's ``count_tokens``.
        request_kwargs = {
            "project_id": GCP_PROJECT,
            "location": GCP_LOCATION,
            "contents": ["In 10 words or less, what is Apache Airflow?"],
            "pretrained_model": "gemini-pro",
        }
        # Arguments the operator must use to construct the hook.
        hook_kwargs = {
            "gcp_conn_id": GCP_CONN_ID,
            "impersonation_chain": IMPERSONATION_CHAIN,
        }

        operator = CountTokensOperator(task_id=TASK_ID, **request_kwargs, **hook_kwargs)
        operator.execute(context={"ti": mock.MagicMock()})

        mock_hook.assert_called_once_with(**hook_kwargs)
        mock_hook.return_value.count_tokens.assert_called_once_with(**request_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
CountTokensOperator,
GenerativeModelGenerateContentOperator,
TextEmbeddingModelGetEmbeddingsOperator,
TextGenerationModelPredictOperator,
Expand Down Expand Up @@ -84,6 +85,16 @@
)
# [END how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]

# [START how_to_cloud_vertex_ai_count_tokens_operator]
count_tokens_task = CountTokensOperator(
task_id="count_tokens_task",
project_id=PROJECT_ID,
contents=CONTENTS,
location=REGION,
pretrained_model=MULTIMODAL_MODEL,
)
# [END how_to_cloud_vertex_ai_count_tokens_operator]

# [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
generate_content_task = GenerativeModelGenerateContentOperator(
task_id="generate_content_task",
Expand Down

0 comments on commit 7cf54a7

Please sign in to comment.