diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py index 4f3743f95c376..6c4b2064c10ba 100644 --- a/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py +++ b/airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py @@ -19,7 +19,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +import asyncio +from typing import TYPE_CHECKING, Any, Sequence from deprecated import deprecated from google.api_core.client_options import ClientOptions @@ -31,15 +32,24 @@ datasets, models, ) -from google.cloud.aiplatform_v1 import JobServiceClient, PipelineServiceClient +from google.cloud.aiplatform_v1 import ( + JobServiceAsyncClient, + JobServiceClient, + JobState, + PipelineServiceAsyncClient, + PipelineServiceClient, + PipelineState, + types, +) from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook +from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook if TYPE_CHECKING: from google.api_core.operation import Operation - from google.api_core.retry import Retry + from google.api_core.retry import AsyncRetry, Retry + from google.auth.credentials import Credentials from google.cloud.aiplatform_v1.services.job_service.pagers import ListCustomJobsPager from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ( ListPipelineJobsPager, @@ -168,7 +178,7 @@ def get_custom_python_package_training_job( training_encryption_spec_key_name: str | None = None, model_encryption_spec_key_name: str | None = None, staging_bucket: str | None = None, - ): + ) -> CustomPythonPackageTrainingJob: """Return CustomPythonPackageTrainingJob object.""" return CustomPythonPackageTrainingJob( display_name=display_name, @@ -218,7 +228,7 @@ def get_custom_training_job( training_encryption_spec_key_name: str | None = None, model_encryption_spec_key_name: str | None = None, staging_bucket: str | None = None, - ): + ) -> CustomTrainingJob: """Return CustomTrainingJob object.""" return CustomTrainingJob( display_name=display_name, @@ -246,10 +256,15 @@ def get_custom_training_job( ) @staticmethod - def extract_model_id(obj: dict) -> str: + def extract_model_id(obj: dict[str, Any]) -> str: """Return unique id of the Model.""" return obj["name"].rpartition("/")[-1] + @staticmethod + def extract_model_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str: + """Return a unique Model id from a serialized TrainingPipeline proto.""" + return training_pipeline["model_to_upload"]["name"].rpartition("/")[-1] + @staticmethod def extract_training_id(resource_name: str) -> str: """Return unique id of the Training pipeline.""" @@ -260,6 +275,11 @@ def extract_custom_job_id(custom_job_name: str) -> str: """Return unique id of the Custom Job pipeline.""" return custom_job_name.rpartition("/")[-1] + @staticmethod + def extract_custom_job_id_from_training_pipeline(training_pipeline: dict[str, Any]) -> str: + """Return a unique Custom Job id from a serialized TrainingPipeline proto.""" + return training_pipeline["training_task_metadata"]["backingCustomJob"].rpartition("/")[-1] + def wait_for_operation(self, operation: Operation, timeout: float | None = None): """Wait for long-lasting operation to complete.""" try: @@ -310,7 +330,7 @@ def _run_job( model_version_aliases: list[str] | None = None, 
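# A minimal sketch of the two new static helpers above, using a TrainingPipeline that has
# already been serialized to a dict (key names mirror the ones accessed in the diff; the
# resource names are made-up placeholders):
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook

pipeline = {
    "model_to_upload": {"name": "projects/123/locations/us-central1/models/456"},
    "training_task_metadata": {
        "backingCustomJob": "projects/123/locations/us-central1/customJobs/789"
    },
}
CustomJobHook.extract_model_id_from_training_pipeline(pipeline)       # -> "456"
CustomJobHook.extract_custom_job_id_from_training_pipeline(pipeline)  # -> "789"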
model_version_description: str | None = None, ) -> tuple[models.Model | None, str, str]: - """Run Job for training pipeline.""" + """Run a training pipeline job and wait until its completion.""" model = job.run( dataset=dataset, annotation_schema_uri=annotation_schema_uri, @@ -1754,79 +1774,1181 @@ def create_custom_training_job( return model, training_id, custom_job_id @GoogleBaseHook.fallback_to_default_project_id - @deprecated( - reason="Please use `PipelineJobHook.delete_pipeline_job`", - category=AirflowProviderDeprecationWarning, - ) - def delete_pipeline_job( + def submit_custom_container_training_job( self, + *, project_id: str, region: str, - pipeline_job: str, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: + display_name: str, + container_uri: str, + command: Sequence[str] = [], + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + parent_model: str | None = None, + is_default_version: bool | None = None, + model_version_aliases: list[str] | None = None, + model_version_description: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, + # RUN + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, + ) -> CustomContainerTrainingJob: """ - Delete a PipelineJob. - - This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method. + Create and submit a Custom Container Training Job pipeline, then exit without waiting for it to complete. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :param pipeline_job: Required. 
The name of the PipelineJob resource to be deleted. - :param retry: Designation of what errors, if any, should be retried. - :param timeout: The timeout for this request. - :param metadata: Strings which should be sent along with the request as metadata. - """ - client = self.get_pipeline_service_client(region) - name = client.pipeline_job_path(project_id, region, pipeline_job) + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param command: The command to be invoked when the container is started. + It overrides the entrypoint instruction in Dockerfile when provided + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. 
Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param parent_model: Optional. The resource name or model ID of an existing model. + The new model uploaded by this job will be a version of `parent_model`. + Only set this field when training a new version of an existing model. + :param is_default_version: Optional. When set to True, the newly uploaded model version will + automatically have alias "default" included. Subsequent uses of + the model produced by this job without a version specified will + use this "default" version. + When set to False, the "default" alias will not be moved. + Actions targeting the model version produced by this job will need + to specifically reference this version by ID or alias. + New model uploads, i.e. version 1, will always be "default" aliased. + :param model_version_aliases: Optional. User provided version aliases so that the model version + uploaded by this job can be referenced via alias instead of + auto-generated version ID. A default version alias will be created + for the first version of the model. + The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9] + :param model_version_description: Optional. The description of the model version + being uploaded by this job. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. 
- result = client.delete_pipeline_job( - request={ - "name": name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result + If set, this TrainingPipeline will be secured by this key. - @GoogleBaseHook.fallback_to_default_project_id - def delete_training_pipeline( - self, - project_id: str, - region: str, - training_pipeline: str, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Delete a TrainingPipeline. + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. - :param project_id: Required. The ID of the Google Cloud project that the service belongs to. - :param region: Required. The ID of the Google Cloud region that the service belongs to. - :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted. - :param retry: Designation of what errors, if any, should be retried. - :param timeout: The timeout for this request. - :param metadata: Strings which should be sent along with the request as metadata. - """ - client = self.get_pipeline_service_client(region) - name = client.training_pipeline_path(project_id, region, training_pipeline) + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) - result = client.delete_training_pipeline( - request={ - "name": name, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. - @GoogleBaseHook.fallback_to_default_project_id - def delete_custom_job( + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. 
+ + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. 
If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. 
Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + """ + self._job = self.get_custom_container_training_job( + project=project_id, + location=region, + display_name=display_name, + container_uri=container_uri, + command=command, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + if not self._job: + raise AirflowException("CustomContainerTrainingJob instance creation failed.") + + self._job.submit( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + parent_model=parent_model, + is_default_version=is_default_version, + model_version_aliases=model_version_aliases, + model_version_description=model_version_description, + sync=False, + ) + return self._job + + @GoogleBaseHook.fallback_to_default_project_id + def submit_custom_python_package_training_job( + self, + *, + project_id: str, + region: str, + display_name: str, + python_package_gcs_uri: str, + python_module_name: str, + container_uri: str, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + model_prediction_schema_uri: str | None = None, + labels: dict[str, str] | None = None, + 
training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, + # RUN + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, + parent_model: str | None = None, + is_default_version: bool | None = None, + model_version_aliases: list[str] | None = None, + model_version_description: str | None = None, + ) -> CustomPythonPackageTrainingJob: + """ + Create and submit a Custom Python Package Training Job pipeline, then exit without waiting for it to complete. + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param python_package_gcs_uri: Required: GCS location of the training python package. + :param python_module_name: Required: The module name of the training python package. + :param container_uri: Required: Uri of the training container image in the GCR. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. + :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. 
The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param parent_model: Optional. The resource name or model ID of an existing model. + The new model uploaded by this job will be a version of `parent_model`. + Only set this field when training a new version of an existing model. + :param is_default_version: Optional. When set to True, the newly uploaded model version will + automatically have alias "default" included. Subsequent uses of + the model produced by this job without a version specified will + use this "default" version. + When set to False, the "default" alias will not be moved. + Actions targeting the model version produced by this job will need + to specifically reference this version by ID or alias. 
+ New model uploads, i.e. version 1, will always be "default" aliased. + :param model_version_aliases: Optional. User provided version aliases so that the model version + uploaded by this job can be referenced via alias instead of + auto-generated version ID. A default version alias will be created + for the first version of the model. + The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9] + :param model_version_description: Optional. The description of the model version + being uploaded by this job. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. 
+ + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. 
If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. 
Format: + ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` + For more information on configuring your service account please visit: + https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training + """ + self._job = self.get_custom_python_package_training_job( + project=project_id, + location=region, + display_name=display_name, + python_package_gcs_uri=python_package_gcs_uri, + python_module_name=python_module_name, + container_uri=container_uri, + model_serving_container_image_uri=model_serving_container_image_uri, + model_serving_container_predict_route=model_serving_container_predict_route, + model_serving_container_health_route=model_serving_container_health_route, + model_serving_container_command=model_serving_container_command, + model_serving_container_args=model_serving_container_args, + model_serving_container_environment_variables=model_serving_container_environment_variables, + model_serving_container_ports=model_serving_container_ports, + model_description=model_description, + model_instance_schema_uri=model_instance_schema_uri, + model_parameters_schema_uri=model_parameters_schema_uri, + model_prediction_schema_uri=model_prediction_schema_uri, + labels=labels, + training_encryption_spec_key_name=training_encryption_spec_key_name, + model_encryption_spec_key_name=model_encryption_spec_key_name, + staging_bucket=staging_bucket, + ) + + if not self._job: + raise AirflowException("CustomPythonPackageTrainingJob instance creation failed.") + + self._job.run( + dataset=dataset, + annotation_schema_uri=annotation_schema_uri, + model_display_name=model_display_name, + model_labels=model_labels, + base_output_dir=base_output_dir, + service_account=service_account, + network=network, + bigquery_destination=bigquery_destination, + args=args, + environment_variables=environment_variables, + replica_count=replica_count, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + boot_disk_type=boot_disk_type, + boot_disk_size_gb=boot_disk_size_gb, + training_fraction_split=training_fraction_split, + validation_fraction_split=validation_fraction_split, + test_fraction_split=test_fraction_split, + training_filter_split=training_filter_split, + validation_filter_split=validation_filter_split, + test_filter_split=test_filter_split, + predefined_split_column_name=predefined_split_column_name, + timestamp_split_column_name=timestamp_split_column_name, + tensorboard=tensorboard, + parent_model=parent_model, + is_default_version=is_default_version, + model_version_aliases=model_version_aliases, + model_version_description=model_version_description, + sync=False, + ) + + return self._job + + @GoogleBaseHook.fallback_to_default_project_id + def submit_custom_training_job( + self, + *, + project_id: str, + region: str, + display_name: str, + script_path: str, + container_uri: str, + requirements: Sequence[str] | None = None, + model_serving_container_image_uri: str | None = None, + model_serving_container_predict_route: str | None = None, + model_serving_container_health_route: str | None = None, + model_serving_container_command: Sequence[str] | None = None, + model_serving_container_args: Sequence[str] | None = None, + model_serving_container_environment_variables: dict[str, str] | None = None, + model_serving_container_ports: Sequence[int] | None = None, + model_description: str | None = None, + model_instance_schema_uri: str | None = None, + model_parameters_schema_uri: str | None = None, + 
model_prediction_schema_uri: str | None = None, + parent_model: str | None = None, + is_default_version: bool | None = None, + model_version_aliases: list[str] | None = None, + model_version_description: str | None = None, + labels: dict[str, str] | None = None, + training_encryption_spec_key_name: str | None = None, + model_encryption_spec_key_name: str | None = None, + staging_bucket: str | None = None, + # RUN + dataset: None + | ( + datasets.ImageDataset | datasets.TabularDataset | datasets.TextDataset | datasets.VideoDataset + ) = None, + annotation_schema_uri: str | None = None, + model_display_name: str | None = None, + model_labels: dict[str, str] | None = None, + base_output_dir: str | None = None, + service_account: str | None = None, + network: str | None = None, + bigquery_destination: str | None = None, + args: list[str | float | int] | None = None, + environment_variables: dict[str, str] | None = None, + replica_count: int = 1, + machine_type: str = "n1-standard-4", + accelerator_type: str = "ACCELERATOR_TYPE_UNSPECIFIED", + accelerator_count: int = 0, + boot_disk_type: str = "pd-ssd", + boot_disk_size_gb: int = 100, + training_fraction_split: float | None = None, + validation_fraction_split: float | None = None, + test_fraction_split: float | None = None, + training_filter_split: str | None = None, + validation_filter_split: str | None = None, + test_filter_split: str | None = None, + predefined_split_column_name: str | None = None, + timestamp_split_column_name: str | None = None, + tensorboard: str | None = None, + ) -> CustomTrainingJob: + """ + Create and submit a Custom Training Job pipeline, then exit without waiting for it to complete. + + Neither the training model nor backing custom job are available at the moment when the training + pipeline is submitted, both are created only after a period of time. Therefore, it is not possible + to extract and return them in this method, this should be done with a separate client request. + + :param display_name: Required. The user-defined name of this TrainingPipeline. + :param script_path: Required. Local path to training script. + :param container_uri: Required: Uri of the training container image in the GCR. + :param requirements: List of python packages dependencies of script. + :param model_serving_container_image_uri: If the training produces a managed Vertex AI Model, the URI + of the Model serving container suitable for serving the model produced by the + training script. + :param model_serving_container_predict_route: If the training produces a managed Vertex AI Model, An + HTTP path to send prediction requests to the container, and which must be supported + by it. If not specified a default HTTP path will be used by Vertex AI. + :param model_serving_container_health_route: If the training produces a managed Vertex AI Model, an + HTTP path to send health check requests to the container, and which must be supported + by it. If not specified a standard HTTP path will be used by AI Platform. + :param model_serving_container_command: The command with which the container is run. Not executed + within a shell. The Docker image's ENTRYPOINT is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. 
+ :param model_serving_container_args: The arguments to the command. The Docker image's CMD is used if + this is not provided. Variable references $(VAR_NAME) are expanded using the + container's environment. If a variable cannot be resolved, the reference + in the input string will be unchanged. The $(VAR_NAME) syntax can be + escaped with a double $$, ie: $$(VAR_NAME). Escaped references will + never be expanded, regardless of whether the variable exists or not. + :param model_serving_container_environment_variables: The environment variables that are to be + present in the container. Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + :param model_serving_container_ports: Declaration of ports that are exposed by the container. This + field is primarily informational, it gives Vertex AI information about the + network connections the container uses. Listing or not a port here has + no impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible from + the network. + :param model_description: The description of the Model. + :param model_instance_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single instance, which + are used in + ``PredictRequest.instances``, + ``ExplainRequest.instances`` + and + ``BatchPredictionJob.input_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param model_parameters_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the parameters of prediction and + explanation via + ``PredictRequest.parameters``, + ``ExplainRequest.parameters`` + and + ``BatchPredictionJob.model_parameters``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform, if no parameters are supported it is set to an + empty string. Note: The URI given on output will be + immutable and probably different, including the URI scheme, + than the one given on input. The output URI will point to a + location where the user only has a read access. + :param model_prediction_schema_uri: Optional. Points to a YAML file stored on Google Cloud + Storage describing the format of a single prediction + produced by this Model, which are returned via + ``PredictResponse.predictions``, + ``ExplainResponse.explanations``, + and + ``BatchPredictionJob.output_config``. + The schema is defined as an OpenAPI 3.0.2 `Schema + Object `__. + AutoML Models always have this field populated by AI + Platform. Note: The URI given on output will be immutable + and probably different, including the URI scheme, than the + one given on input. The output URI will point to a location + where the user only has a read access. + :param parent_model: Optional. The resource name or model ID of an existing model. + The new model uploaded by this job will be a version of `parent_model`. + Only set this field when training a new version of an existing model. + :param is_default_version: Optional. When set to True, the newly uploaded model version will + automatically have alias "default" included. 
Subsequent uses of + the model produced by this job without a version specified will + use this "default" version. + When set to False, the "default" alias will not be moved. + Actions targeting the model version produced by this job will need + to specifically reference this version by ID or alias. + New model uploads, i.e. version 1, will always be "default" aliased. + :param model_version_aliases: Optional. User provided version aliases so that the model version + uploaded by this job can be referenced via alias instead of + auto-generated version ID. A default version alias will be created + for the first version of the model. + The format is [a-z][a-zA-Z0-9-]{0,126}[a-z0-9] + :param model_version_description: Optional. The description of the model version + being uploaded by this job. + :param project_id: Project to run training in. + :param region: Location to run training in. + :param labels: Optional. The labels with user-defined metadata to + organize TrainingPipelines. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + See https://goo.gl/xmQnxf for more information + and examples of labels. + :param training_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the training pipeline. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, this TrainingPipeline will be secured by this key. + + Note: Model trained by this TrainingPipeline is also secured + by this key if ``model_to_upload`` is not set separately. + :param model_encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer + managed encryption key used to protect the model. Has the + form: + ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``. + The key needs to be in the same region as where the compute + resource is created. + + If set, the trained Model will be secured by this key. + :param staging_bucket: Bucket used to stage source and training artifacts. + :param dataset: Vertex AI to fit this training against. + :param annotation_schema_uri: Google Cloud Storage URI points to a YAML file describing + annotation schema. The schema is defined as an OpenAPI 3.0.2 + [Schema Object] + (https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.0.2.md#schema-object) + + Only Annotations that both match this schema and belong to + DataItems not ignored by the split method are used in + respectively training, validation or test role, depending on + the role of the DataItem they are on. + + When used in conjunction with + ``annotations_filter``, + the Annotations used for training are filtered by both + ``annotations_filter`` + and + ``annotation_schema_uri``. + :param model_display_name: If the script produces a managed Vertex AI Model. The display name of + the Model. The name can be up to 128 characters long and can be consist + of any UTF-8 characters. + + If not provided upon creation, the job's display_name is used. + :param model_labels: Optional. The labels with user-defined metadata to + organize your Models. + Label keys and values can be no longer than 64 + characters, can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. 
+ See https://goo.gl/xmQnxf for more information + and examples of labels. + :param base_output_dir: GCS output directory of job. If not provided a timestamped directory in the + staging directory will be used. + + Vertex AI sets the following environment variables when it runs your training code: + + - AIP_MODEL_DIR: a Cloud Storage URI of a directory intended for saving model artifacts, + i.e. /model/ + - AIP_CHECKPOINT_DIR: a Cloud Storage URI of a directory intended for saving checkpoints, + i.e. /checkpoints/ + - AIP_TENSORBOARD_LOG_DIR: a Cloud Storage URI of a directory intended for saving TensorBoard + logs, i.e. /logs/ + :param service_account: Specifies the service account for workload run-as account. + Users submitting jobs must have act-as permission on this run-as account. + :param network: The full name of the Compute Engine network to which the job + should be peered. + Private services access must already be configured for the network. + If left unspecified, the job is not peered with any network. + :param bigquery_destination: Provide this field if `dataset` is a BiqQuery dataset. + The BigQuery project location where the training data is to + be written to. In the given project a new dataset is created + with name + ``dataset___`` + where timestamp is in YYYY_MM_DDThh_mm_ss_sssZ format. All + training input data will be written into that dataset. In + the dataset three tables will be created, ``training``, + ``validation`` and ``test``. + + - AIP_DATA_FORMAT = "bigquery". + - AIP_TRAINING_DATA_URI ="bigquery_destination.dataset_*.training" + - AIP_VALIDATION_DATA_URI = "bigquery_destination.dataset_*.validation" + - AIP_TEST_DATA_URI = "bigquery_destination.dataset_*.test" + :param args: Command line arguments to be passed to the Python script. + :param environment_variables: Environment variables to be passed to the container. + Should be a dictionary where keys are environment variable names + and values are environment variable values for those names. + At most 10 environment variables can be specified. + The Name of the environment variable must be unique. + :param replica_count: The number of worker replicas. If replica count = 1 then one chief + replica will be provisioned. If replica_count > 1 the remainder will be + provisioned as a worker replica pool. + :param machine_type: The type of machine to use for training. + :param accelerator_type: Hardware accelerator type. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, + NVIDIA_TESLA_T4 + :param accelerator_count: The number of accelerators to attach to a worker replica. + :param boot_disk_type: Type of the boot disk, default is `pd-ssd`. + Valid values: `pd-ssd` (Persistent Disk Solid State Drive) or + `pd-standard` (Persistent Disk Hard Disk Drive). + :param boot_disk_size_gb: Size in GB of the boot disk, default is 100GB. + boot disk size must be within the range of [100, 64000]. + :param training_fraction_split: Optional. The fraction of the input data that is to be used to train + the Model. This is ignored if Dataset is not provided. + :param validation_fraction_split: Optional. The fraction of the input data that is to be used to + validate the Model. This is ignored if Dataset is not provided. + :param test_fraction_split: Optional. The fraction of the input data that is to be used to evaluate + the Model. This is ignored if Dataset is not provided. + :param training_filter_split: Optional. A filter on DataItems of the Dataset. 
DataItems that match + this filter are used to train the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param validation_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to validate the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param test_filter_split: Optional. A filter on DataItems of the Dataset. DataItems that match + this filter are used to test the Model. A filter with same syntax + as the one used in DatasetService.ListDataItems may be used. If a + single DataItem is matched by more than one of the FilterSplit filters, + then it is assigned to the first set that applies to it in the training, + validation, test order. This is ignored if Dataset is not provided. + :param predefined_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key (either the label's value or + value in the column) must be one of {``training``, + ``validation``, ``test``}, and it defines to which set the + given piece of data is assigned. If for a piece of data the + key is not present or has an invalid value, that piece is + ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param timestamp_split_column_name: Optional. The key is a name of one of the Dataset's data + columns. The value of the key values of the key (the values in + the column) must be in RFC 3339 `date-time` format, where + `time-offset` = `"Z"` (e.g. 1985-04-12T23:20:50.52Z). If for a + piece of data the key is not present or has an invalid value, + that piece is ignored by the pipeline. + + Supported only for tabular and time series Datasets. + :param tensorboard: Optional. The name of a Vertex AI resource to which this CustomJob will upload + logs. 
Format:
+            ``projects/{project}/locations/{location}/tensorboards/{tensorboard}``
+            For more information on configuring your service account please visit:
+            https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training
+        """
+        self._job = self.get_custom_container_training_job(
+            project=project_id,
+            location=region,
+            display_name=display_name,
+            container_uri=container_uri,
+            command=command,
+            model_serving_container_image_uri=model_serving_container_image_uri,
+            model_serving_container_predict_route=model_serving_container_predict_route,
+            model_serving_container_health_route=model_serving_container_health_route,
+            model_serving_container_command=model_serving_container_command,
+            model_serving_container_args=model_serving_container_args,
+            model_serving_container_environment_variables=model_serving_container_environment_variables,
+            model_serving_container_ports=model_serving_container_ports,
+            model_description=model_description,
+            model_instance_schema_uri=model_instance_schema_uri,
+            model_parameters_schema_uri=model_parameters_schema_uri,
+            model_prediction_schema_uri=model_prediction_schema_uri,
+            labels=labels,
+            training_encryption_spec_key_name=training_encryption_spec_key_name,
+            model_encryption_spec_key_name=model_encryption_spec_key_name,
+            staging_bucket=staging_bucket,
+        )
+
+        if not self._job:
+            raise AirflowException("CustomContainerTrainingJob instance creation failed.")
+
+        self._job.submit(
+            dataset=dataset,
+            annotation_schema_uri=annotation_schema_uri,
+            model_display_name=model_display_name,
+            model_labels=model_labels,
+            base_output_dir=base_output_dir,
+            service_account=service_account,
+            network=network,
+            bigquery_destination=bigquery_destination,
+            args=args,
+            environment_variables=environment_variables,
+            replica_count=replica_count,
+            machine_type=machine_type,
+            accelerator_type=accelerator_type,
+            accelerator_count=accelerator_count,
+            boot_disk_type=boot_disk_type,
+            boot_disk_size_gb=boot_disk_size_gb,
+            training_fraction_split=training_fraction_split,
+            validation_fraction_split=validation_fraction_split,
+            test_fraction_split=test_fraction_split,
+            training_filter_split=training_filter_split,
+            validation_filter_split=validation_filter_split,
+            test_filter_split=test_filter_split,
+            predefined_split_column_name=predefined_split_column_name,
+            timestamp_split_column_name=timestamp_split_column_name,
+            tensorboard=tensorboard,
+            parent_model=parent_model,
+            is_default_version=is_default_version,
+            model_version_aliases=model_version_aliases,
+            model_version_description=model_version_description,
+            sync=False,
+        )
+        return self._job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def delete_training_pipeline(
+        self,
+        project_id: str,
+        region: str,
+        training_pipeline: str,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """
+        Delete a TrainingPipeline.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param training_pipeline: Required. The name of the TrainingPipeline resource to be deleted.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
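+
+        A minimal usage sketch (``hook`` here is assumed to be an already initialized ``CustomJobHook``;
+        the project, region and pipeline id values are placeholders)::
+
+            operation = hook.delete_training_pipeline(
+                project_id="my-project",
+                region="us-central1",
+                training_pipeline="1234567890",
+            )
+            hook.wait_for_operation(operation)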
+ """ + client = self.get_pipeline_service_client(region) + name = client.training_pipeline_path(project_id, region, training_pipeline) + + result = client.delete_training_pipeline( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + @GoogleBaseHook.fallback_to_default_project_id + def delete_custom_job( self, project_id: str, region: str, @@ -2178,3 +3300,239 @@ def list_custom_jobs( metadata=metadata, ) return result + + @GoogleBaseHook.fallback_to_default_project_id + @deprecated( + reason="Please use `PipelineJobHook.delete_pipeline_job`", + category=AirflowProviderDeprecationWarning, + ) + def delete_pipeline_job( + self, + project_id: str, + region: str, + pipeline_job: str, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + ) -> Operation: + """ + Delete a PipelineJob. + + This method is deprecated, please use `PipelineJobHook.delete_pipeline_job` method. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param region: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_job: Required. The name of the PipelineJob resource to be deleted. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + client = self.get_pipeline_service_client(region) + name = client.pipeline_job_path(project_id, region, pipeline_job) + + result = client.delete_pipeline_job( + request={ + "name": name, + }, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + +class CustomJobAsyncHook(GoogleBaseAsyncHook): + """Async hook for Custom Job Service Client.""" + + sync_hook_class = CustomJobHook + JOB_COMPLETE_STATES = { + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_PAUSED, + JobState.JOB_STATE_SUCCEEDED, + } + PIPELINE_COMPLETE_STATES = ( + PipelineState.PIPELINE_STATE_CANCELLED, + PipelineState.PIPELINE_STATE_FAILED, + PipelineState.PIPELINE_STATE_PAUSED, + PipelineState.PIPELINE_STATE_SUCCEEDED, + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + **kwargs, + ): + super().__init__( + gcp_conn_id=gcp_conn_id, + impersonation_chain=impersonation_chain, + **kwargs, + ) + self._job: None | ( + CustomContainerTrainingJob | CustomPythonPackageTrainingJob | CustomTrainingJob + ) = None + + async def get_credentials(self) -> Credentials: + return (await self.get_sync_hook()).get_credentials() + + async def get_job_service_client( + self, + region: str | None = None, + ) -> JobServiceAsyncClient: + """Retrieve Vertex AI JobServiceAsyncClient object.""" + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") + else: + client_options = ClientOptions() + return JobServiceAsyncClient( + credentials=(await self.get_credentials()), + client_info=CLIENT_INFO, + client_options=client_options, + ) + + async def get_pipeline_service_client( + self, + region: str | None = None, + ) -> PipelineServiceAsyncClient: + """Retrieve Vertex AI PipelineServiceAsyncClient object.""" + if region and region != "global": + client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443") + else: + client_options = ClientOptions() + return 
PipelineServiceAsyncClient( + credentials=(await self.get_credentials()), + client_info=CLIENT_INFO, + client_options=client_options, + ) + + async def get_custom_job( + self, + project_id: str, + location: str, + job_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault | None = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + client: JobServiceAsyncClient | None = None, + ) -> types.CustomJob: + """ + Get a CustomJob proto message from JobServiceAsyncClient. + + :param project_id: Required. The ID of the Google Cloud project that the job belongs to. + :param location: Required. The ID of the Google Cloud region that the job belongs to. + :param job_id: Required. The hyperparameter tuning job id. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. + """ + if not client: + client = await self.get_job_service_client(region=location) + job_name = client.custom_job_path(project_id, location, job_id) + result: types.CustomJob = await client.get_custom_job( + request={"name": job_name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return result + + async def get_training_pipeline( + self, + project_id: str, + location: str, + pipeline_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | _MethodDefault | None = DEFAULT, + metadata: Sequence[tuple[str, str]] = (), + client: PipelineServiceAsyncClient | None = None, + ) -> types.TrainingPipeline: + """ + Get a TrainingPipeline proto message from PipelineServiceAsyncClient. + + :param project_id: Required. The ID of the Google Cloud project that the service belongs to. + :param location: Required. The ID of the Google Cloud region that the service belongs to. + :param pipeline_id: Required. The ID of the PipelineJob resource. + :param retry: Designation of what errors, if any, should be retried. + :param timeout: The timeout for this request. + :param metadata: Strings which should be sent along with the request as metadata. 
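+
+        A minimal usage sketch (assumes ``hook`` is an already created ``CustomJobAsyncHook`` and the
+        call runs inside an event loop; the ids below are placeholders)::
+
+            pipeline = await hook.get_training_pipeline(
+                project_id="my-project",
+                location="us-central1",
+                pipeline_id="1234567890",
+            )
+            print(pipeline.state)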
+ """ + if not client: + client = await self.get_pipeline_service_client(region=location) + pipeline_name = client.training_pipeline_path( + project=project_id, + location=location, + training_pipeline=pipeline_id, + ) + response: types.TrainingPipeline = await client.get_training_pipeline( + request={"name": pipeline_name}, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return response + + async def wait_for_custom_job( + self, + project_id: str, + location: str, + job_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + poll_interval: int = 10, + ) -> types.CustomJob: + client = await self.get_job_service_client(region=location) + while True: + try: + self.log.info("Requesting a custom job with id %s", job_id) + job: types.CustomJob = await self.get_custom_job( + project_id=project_id, + location=location, + job_id=job_id, + retry=retry, + timeout=timeout, + metadata=metadata, + client=client, + ) + except Exception as ex: + self.log.exception("Exception occurred while requesting job %s", job_id) + raise AirflowException(ex) + self.log.info("Status of the custom job %s is %s", job.name, job.state.name) + if job.state in self.JOB_COMPLETE_STATES: + return job + self.log.info("Sleeping for %s seconds.", poll_interval) + await asyncio.sleep(poll_interval) + + async def wait_for_training_pipeline( + self, + project_id: str, + location: str, + pipeline_id: str, + retry: AsyncRetry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + poll_interval: int = 10, + ) -> types.TrainingPipeline: + """Make async calls to Vertex AI to check the training pipeline state until it is complete.""" + client = await self.get_pipeline_service_client(region=location) + while True: + try: + self.log.info("Requesting a training pipeline with id %s", pipeline_id) + pipeline: types.TrainingPipeline = await self.get_training_pipeline( + project_id=project_id, + location=location, + pipeline_id=pipeline_id, + retry=retry, + timeout=timeout, + metadata=metadata, + client=client, + ) + except Exception as ex: + self.log.exception("Exception occurred while requesting training pipeline %s", pipeline_id) + raise AirflowException(ex) + self.log.info("Status of the training pipeline %s is %s", pipeline.name, pipeline.state.name) + if pipeline.state in self.PIPELINE_COMPLETE_STATES: + return pipeline + self.log.info("Sleeping for %s seconds.", poll_interval) + await asyncio.sleep(poll_interval) diff --git a/airflow/providers/google/cloud/links/vertex_ai.py b/airflow/providers/google/cloud/links/vertex_ai.py index aa8d94c1663c8..8463510cd3919 100644 --- a/airflow/providers/google/cloud/links/vertex_ai.py +++ b/airflow/providers/google/cloud/links/vertex_ai.py @@ -25,7 +25,8 @@ VERTEX_AI_BASE_LINK = "/vertex-ai" VERTEX_AI_MODEL_LINK = ( - VERTEX_AI_BASE_LINK + "/locations/{region}/models/{model_id}/deploy?project={project_id}" + VERTEX_AI_BASE_LINK + + "/models/locations/{region}/models/{model_id}/versions/default/properties?project={project_id}" ) VERTEX_AI_MODEL_LIST_LINK = VERTEX_AI_BASE_LINK + "/models?project={project_id}" VERTEX_AI_MODEL_EXPORT_LINK = "/storage/browser/{bucket_name}/model-{model_id}?project={project_id}" diff --git a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py index 8802c3a26a05f..c0f1bd7c44e67 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py 
+++ b/airflow/providers/google/cloud/operators/vertex_ai/custom_job.py @@ -19,7 +19,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +import warnings +from functools import cached_property +from typing import TYPE_CHECKING, Any, Sequence from deprecated import deprecated from google.api_core.exceptions import NotFound @@ -28,7 +30,8 @@ from google.cloud.aiplatform_v1.types.dataset import Dataset from google.cloud.aiplatform_v1.types.training_pipeline import TrainingPipeline -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.configuration import conf +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook from airflow.providers.google.cloud.links.vertex_ai import ( VertexAIModelLink, @@ -36,9 +39,19 @@ VertexAITrainingPipelinesLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import ( + CustomContainerTrainingJobTrigger, + CustomPythonPackageTrainingJobTrigger, + CustomTrainingJobTrigger, +) if TYPE_CHECKING: from google.api_core.retry import Retry + from google.cloud.aiplatform import ( + CustomContainerTrainingJob, + CustomPythonPackageTrainingJob, + CustomTrainingJob, + ) from airflow.utils.context import Context @@ -421,9 +434,6 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :param sync: Whether to execute the AI Platform job synchronously. If False, this method - will be executed in concurrent Future and any downstream object will - be immediately returned and synced when the Future has completed. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -433,6 +443,9 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the task in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. + The default is 60 seconds. 
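+
+    A minimal, illustrative instantiation in deferrable mode (the project, image, bucket and command
+    values below are placeholders, not defaults)::
+
+        create_custom_container_training_job = CreateCustomContainerTrainingJobOperator(
+            task_id="custom_container_training_job_deferrable",
+            project_id="my-project",
+            region="us-central1",
+            display_name="train-custom-container",
+            container_uri="us-docker.pkg.dev/my-project/my-repo/trainer:latest",
+            command=["python3", "task.py"],
+            staging_bucket="gs://my-staging-bucket",
+            deferrable=True,
+            poll_interval=60,
+        )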
""" template_fields = ( @@ -442,7 +455,10 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator): "dataset_id", "impersonation_chain", ) - operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) + operator_extra_links = ( + VertexAIModelLink(), + VertexAITrainingLink(), + ) def __init__( self, @@ -452,6 +468,8 @@ def __init__( parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 60, **kwargs, ) -> None: super().__init__( @@ -462,12 +480,19 @@ def __init__( **kwargs, ) self.command = command + self.deferrable = deferrable + self.poll_interval = poll_interval def execute(self, context: Context): - self.hook = CustomJobHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + warnings.warn( + "The 'sync' parameter is deprecated and will be removed after 01.10.2024.", + AirflowProviderDeprecationWarning, + stacklevel=2, ) + + if self.deferrable: + self.invoke_defer(context=context) + model, training_id, custom_job_id = self.hook.create_custom_container_training_job( project_id=self.project_id, region=self.region, @@ -539,6 +564,94 @@ def on_kill(self) -> None: if self.hook: self.hook.cancel_job() + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None: + if event["status"] == "error": + raise AirflowException(event["message"]) + result = event["job"] + model_id = self.hook.extract_model_id_from_training_pipeline(result) + custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result) + self.xcom_push(context, key="model_id", value=model_id) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + # push custom_job_id to xcom so it could be pulled by other tasks + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + return result + + def invoke_defer(self, context: Context) -> None: + custom_container_training_job_obj: CustomContainerTrainingJob = self.hook.submit_custom_container_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + command=self.command, + container_uri=self.container_uri, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + parent_model=self.parent_model, + is_default_version=self.is_default_version, + model_version_aliases=self.model_version_aliases, + model_version_description=self.model_version_description, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=Dataset(name=self.dataset_id) if 
self.dataset_id else None, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + ) + custom_container_training_job_obj.wait_for_resource_creation() + training_pipeline_id: str = custom_container_training_job_obj.name + self.xcom_push(context, key="training_id", value=training_pipeline_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id) + self.defer( + trigger=CustomContainerTrainingJobTrigger( + conn_id=self.gcp_conn_id, + project_id=self.project_id, + location=self.region, + job_id=training_pipeline_id, + poll_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + @cached_property + def hook(self) -> CustomJobHook: + return CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator): """Create Custom Python Package Training job. @@ -800,9 +913,6 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :param sync: Whether to execute the AI Platform job synchronously. If False, this method - will be executed in concurrent Future and any downstream object will - be immediately returned and synced when the Future has completed. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -812,6 +922,9 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the task in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. + The default is 60 seconds. 
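+
+    A minimal, illustrative instantiation in deferrable mode (the package URI, module name and other
+    values below are placeholders)::
+
+        create_custom_python_package_training_job = CreateCustomPythonPackageTrainingJobOperator(
+            task_id="python_package_training_job_deferrable",
+            project_id="my-project",
+            region="us-central1",
+            display_name="train-python-package",
+            python_package_gcs_uri="gs://my-bucket/trainer-0.1.tar.gz",
+            python_module_name="trainer.task",
+            container_uri="us-docker.pkg.dev/my-project/my-repo/trainer:latest",
+            staging_bucket="gs://my-staging-bucket",
+            deferrable=True,
+            poll_interval=60,
+        )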
""" template_fields = ( @@ -831,6 +944,8 @@ def __init__( parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 60, **kwargs, ) -> None: super().__init__( @@ -842,12 +957,19 @@ def __init__( ) self.python_package_gcs_uri = python_package_gcs_uri self.python_module_name = python_module_name + self.deferrable = deferrable + self.poll_interval = poll_interval def execute(self, context: Context): - self.hook = CustomJobHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + warnings.warn( + "The 'sync' parameter is deprecated and will be removed after 01.10.2024.", + AirflowProviderDeprecationWarning, + stacklevel=2, ) + + if self.deferrable: + self.invoke_defer(context=context) + model, training_id, custom_job_id = self.hook.create_custom_python_package_training_job( project_id=self.project_id, region=self.region, @@ -920,9 +1042,98 @@ def on_kill(self) -> None: if self.hook: self.hook.cancel_job() + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None: + if event["status"] == "error": + raise AirflowException(event["message"]) + result = event["job"] + model_id = self.hook.extract_model_id_from_training_pipeline(result) + custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result) + self.xcom_push(context, key="model_id", value=model_id) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + # push custom_job_id to xcom so it could be pulled by other tasks + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + return result + + def invoke_defer(self, context: Context) -> None: + custom_python_training_job_obj: CustomPythonPackageTrainingJob = self.hook.submit_custom_python_package_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + python_package_gcs_uri=self.python_package_gcs_uri, + python_module_name=self.python_module_name, + container_uri=self.container_uri, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + parent_model=self.parent_model, + is_default_version=self.is_default_version, + model_version_aliases=self.model_version_aliases, + model_version_description=self.model_version_description, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + dataset=Dataset(name=self.dataset_id) if self.dataset_id else None, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + 
base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + ) + custom_python_training_job_obj.wait_for_resource_creation() + training_pipeline_id: str = custom_python_training_job_obj.name + self.xcom_push(context, key="training_id", value=training_pipeline_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id) + self.defer( + trigger=CustomPythonPackageTrainingJobTrigger( + conn_id=self.gcp_conn_id, + project_id=self.project_id, + location=self.region, + job_id=training_pipeline_id, + poll_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + @cached_property + def hook(self) -> CustomJobHook: + return CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): - """Create Custom Training job. + """Create a Custom Training Job pipeline. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. @@ -1181,9 +1392,6 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): ``projects/{project}/locations/{location}/tensorboards/{tensorboard}`` For more information on configuring your service account please visit: https://cloud.google.com/vertex-ai/docs/experiments/tensorboard-training - :param sync: Whether to execute the AI Platform job synchronously. If False, this method - will be executed in concurrent Future and any downstream object will - be immediately returned and synced when the Future has completed. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -1193,6 +1401,9 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param deferrable: If True, run the task in the deferrable mode. + :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. + The default is 60 seconds. 
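+
+    A minimal, illustrative instantiation in deferrable mode (the script path, image and bucket values
+    below are placeholders)::
+
+        create_custom_training_job = CreateCustomTrainingJobOperator(
+            task_id="custom_training_job_deferrable",
+            project_id="my-project",
+            region="us-central1",
+            display_name="train-custom-script",
+            script_path="dags/scripts/task.py",
+            requirements=["scikit-learn"],
+            container_uri="us-docker.pkg.dev/my-project/my-repo/trainer:latest",
+            staging_bucket="gs://my-staging-bucket",
+            deferrable=True,
+            poll_interval=60,
+        )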
""" template_fields = ( @@ -1203,7 +1414,10 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator): "dataset_id", "impersonation_chain", ) - operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink()) + operator_extra_links = ( + VertexAIModelLink(), + VertexAITrainingLink(), + ) def __init__( self, @@ -1214,6 +1428,8 @@ def __init__( parent_model: str | None = None, impersonation_chain: str | Sequence[str] | None = None, dataset_id: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + poll_interval: int = 60, **kwargs, ) -> None: super().__init__( @@ -1225,12 +1441,19 @@ def __init__( ) self.requirements = requirements self.script_path = script_path + self.deferrable = deferrable + self.poll_interval = poll_interval def execute(self, context: Context): - self.hook = CustomJobHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + warnings.warn( + "The 'sync' parameter is deprecated and will be removed after 01.10.2024.", + AirflowProviderDeprecationWarning, + stacklevel=2, ) + + if self.deferrable: + self.invoke_defer(context=context) + model, training_id, custom_job_id = self.hook.create_custom_training_job( project_id=self.project_id, region=self.region, @@ -1303,6 +1526,95 @@ def on_kill(self) -> None: if self.hook: self.hook.cancel_job() + def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any] | None: + if event["status"] == "error": + raise AirflowException(event["message"]) + result = event["job"] + model_id = self.hook.extract_model_id_from_training_pipeline(result) + custom_job_id = self.hook.extract_custom_job_id_from_training_pipeline(result) + self.xcom_push(context, key="model_id", value=model_id) + VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id) + # push custom_job_id to xcom so it could be pulled by other tasks + self.xcom_push(context, key="custom_job_id", value=custom_job_id) + return result + + def invoke_defer(self, context: Context) -> None: + custom_training_job_obj: CustomTrainingJob = self.hook.submit_custom_training_job( + project_id=self.project_id, + region=self.region, + display_name=self.display_name, + script_path=self.script_path, + container_uri=self.container_uri, + requirements=self.requirements, + model_serving_container_image_uri=self.model_serving_container_image_uri, + model_serving_container_predict_route=self.model_serving_container_predict_route, + model_serving_container_health_route=self.model_serving_container_health_route, + model_serving_container_command=self.model_serving_container_command, + model_serving_container_args=self.model_serving_container_args, + model_serving_container_environment_variables=self.model_serving_container_environment_variables, + model_serving_container_ports=self.model_serving_container_ports, + model_description=self.model_description, + model_instance_schema_uri=self.model_instance_schema_uri, + model_parameters_schema_uri=self.model_parameters_schema_uri, + model_prediction_schema_uri=self.model_prediction_schema_uri, + parent_model=self.parent_model, + is_default_version=self.is_default_version, + model_version_aliases=self.model_version_aliases, + model_version_description=self.model_version_description, + labels=self.labels, + training_encryption_spec_key_name=self.training_encryption_spec_key_name, + model_encryption_spec_key_name=self.model_encryption_spec_key_name, + staging_bucket=self.staging_bucket, + # RUN + 
dataset=Dataset(name=self.dataset_id) if self.dataset_id else None, + annotation_schema_uri=self.annotation_schema_uri, + model_display_name=self.model_display_name, + model_labels=self.model_labels, + base_output_dir=self.base_output_dir, + service_account=self.service_account, + network=self.network, + bigquery_destination=self.bigquery_destination, + args=self.args, + environment_variables=self.environment_variables, + replica_count=self.replica_count, + machine_type=self.machine_type, + accelerator_type=self.accelerator_type, + accelerator_count=self.accelerator_count, + boot_disk_type=self.boot_disk_type, + boot_disk_size_gb=self.boot_disk_size_gb, + training_fraction_split=self.training_fraction_split, + validation_fraction_split=self.validation_fraction_split, + test_fraction_split=self.test_fraction_split, + training_filter_split=self.training_filter_split, + validation_filter_split=self.validation_filter_split, + test_filter_split=self.test_filter_split, + predefined_split_column_name=self.predefined_split_column_name, + timestamp_split_column_name=self.timestamp_split_column_name, + tensorboard=self.tensorboard, + ) + custom_training_job_obj.wait_for_resource_creation() + training_pipeline_id: str = custom_training_job_obj.name + self.xcom_push(context, key="training_id", value=training_pipeline_id) + VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_pipeline_id) + self.defer( + trigger=CustomTrainingJobTrigger( + conn_id=self.gcp_conn_id, + project_id=self.project_id, + location=self.region, + job_id=training_pipeline_id, + poll_interval=self.poll_interval, + impersonation_chain=self.impersonation_chain, + ), + method_name="execute_complete", + ) + + @cached_property + def hook(self) -> CustomJobHook: + return CustomJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + class DeleteCustomTrainingJobOperator(GoogleCloudBaseOperator): """ diff --git a/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py index 5dd0faef8aef1..91f3f50fddb57 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py @@ -102,7 +102,6 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param deferrable: If True, run the task in the deferrable mode. - Note that it requires calling the operator with `sync=False` parameter. :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job. The default is 300 seconds. 
""" diff --git a/airflow/providers/google/cloud/triggers/vertex_ai.py b/airflow/providers/google/cloud/triggers/vertex_ai.py index 4a3816f8f6ae7..eeb9d15cc1d95 100644 --- a/airflow/providers/google/cloud/triggers/vertex_ai.py +++ b/airflow/providers/google/cloud/triggers/vertex_ai.py @@ -29,6 +29,7 @@ from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.vertex_ai.batch_prediction_job import BatchPredictionJobAsyncHook +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import ( HyperparameterTuningJobAsyncHook, ) @@ -189,3 +190,96 @@ async def _wait_job(self) -> types.PipelineJob: poll_interval=self.poll_interval, ) return job + + +class CustomTrainingJobTrigger(BaseVertexAIJobTrigger): + """ + Make async calls to Vertex AI to check the state of a running custom training job. + + Return the job when it enters a completed state. + """ + + job_type_verbose_name = "Custom Training Job" + job_serializer_class = types.TrainingPipeline + statuses_success = { + PipelineState.PIPELINE_STATE_PAUSED, + PipelineState.PIPELINE_STATE_SUCCEEDED, + } + + @cached_property + def async_hook(self) -> CustomJobAsyncHook: + return CustomJobAsyncHook( + gcp_conn_id=self.conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def _wait_job(self) -> types.TrainingPipeline: + pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline( + project_id=self.project_id, + location=self.location, + pipeline_id=self.job_id, + poll_interval=self.poll_interval, + ) + return pipeline + + +class CustomContainerTrainingJobTrigger(BaseVertexAIJobTrigger): + """ + Make async calls to Vertex AI to check the state of a running custom container training job. + + Return the job when it enters a completed state. + """ + + job_type_verbose_name = "Custom Container Training Job" + job_serializer_class = types.TrainingPipeline + statuses_success = { + PipelineState.PIPELINE_STATE_PAUSED, + PipelineState.PIPELINE_STATE_SUCCEEDED, + } + + @cached_property + def async_hook(self) -> CustomJobAsyncHook: + return CustomJobAsyncHook( + gcp_conn_id=self.conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def _wait_job(self) -> types.TrainingPipeline: + pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline( + project_id=self.project_id, + location=self.location, + pipeline_id=self.job_id, + poll_interval=self.poll_interval, + ) + return pipeline + + +class CustomPythonPackageTrainingJobTrigger(BaseVertexAIJobTrigger): + """ + Make async calls to Vertex AI to check the state of a running custom python package training job. + + Return the job when it enters a completed state. 
+ """ + + job_type_verbose_name = "Custom Python Package Training Job" + job_serializer_class = types.TrainingPipeline + statuses_success = { + PipelineState.PIPELINE_STATE_PAUSED, + PipelineState.PIPELINE_STATE_SUCCEEDED, + } + + @cached_property + def async_hook(self) -> CustomJobAsyncHook: + return CustomJobAsyncHook( + gcp_conn_id=self.conn_id, + impersonation_chain=self.impersonation_chain, + ) + + async def _wait_job(self) -> types.TrainingPipeline: + pipeline: types.TrainingPipeline = await self.async_hook.wait_for_training_pipeline( + project_id=self.project_id, + location=self.location, + pipeline_id=self.job_id, + poll_interval=self.poll_interval, + ) + return pipeline diff --git a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst index 24c51430d351c..c93ce54577530 100644 --- a/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst +++ b/docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst @@ -107,7 +107,7 @@ Preparation step For each operator you must prepare and create dataset. Then put dataset id to ``dataset_id`` parameter in operator. -How to run Container Training Job +How to run a Custom Container Training Job :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator` Before start running this Job you should create a docker image with training script inside. Documentation how to @@ -121,7 +121,16 @@ for container which will be created from this image in ``command`` parameter. :start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator] :end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] -How to run Python Package Training Job +The :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomContainerTrainingJobOperator` +also provides the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable] + +How to run a Python Package Training Job :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator` Before start running this Job you should create a python package with training script inside. Documentation how to @@ -135,10 +144,19 @@ parameter should has the name of script which will run your training task. :start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] :end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator] -How to run Training Job +The :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator` +also provides the deferrable mode: + +.. 
exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable] + +How to run a Custom Training Job :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomTrainingJobOperator`. -For this Job you should put path to your local training script inside ``script_path`` parameter. +To create and run a Custom Training Job you should put the path to your local training script inside the ``script_path`` parameter. .. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py :language: python @@ -146,9 +164,17 @@ For this Job you should put path to your local training script inside ``script_p :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_operator] :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_operator] -Additionally, you can create new version of existing Training Job instead. In this case, the result will be new -version of existing Model instead of new Model created in Model Registry. This can be done by specifying -``parent_model`` parameter when running Training Job. +The same operation can be performed in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable] + +Additionally, you can create a new version of an existing Custom Training Job. It will replace the existing +Model with another version, instead of creating a new Model in the Model Registry. +This can be done by specifying the ``parent_model`` parameter when running a Custom Training Job. .. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py :language: python @@ -156,6 +182,14 @@ version of existing Model instead of new Model created in Model Registry. This c :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator] :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator] +The same operation can be performed in the deferrable mode: + +.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable] + :end-before: [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable] + You can get a list of Training Jobs using :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.ListCustomTrainingJobOperator`. diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py index 839fc2b3cc5e9..0e2d42fe26a09 100644 --- a/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py +++ b/tests/providers/google/cloud/hooks/vertex_ai/test_custom_job.py @@ -17,18 +17,31 @@ # under the License. 
from __future__ import annotations +import asyncio from unittest import mock import pytest from google.api_core.gapic_v1.method import DEFAULT +from google.cloud.aiplatform_v1 import JobServiceAsyncClient, PipelineServiceAsyncClient -from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook +from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import ( + CustomJobAsyncHook, + CustomJobHook, + JobState, + PipelineState, + types, +) from tests.providers.google.cloud.utils.base_gcp_mock import ( mock_base_gcp_hook_default_project_id, mock_base_gcp_hook_no_default_project_id, ) TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_IMPERSONATION_CHAIN = [ + "TEST", + "PERSONA", +] TEST_REGION: str = "test-region" TEST_PROJECT_ID: str = "test-project-id" TEST_PIPELINE_JOB: dict = {} @@ -40,6 +53,46 @@ CUSTOM_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.custom_job.{}" +@pytest.fixture +def test_async_hook(): + return CustomJobAsyncHook( + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + +@pytest.fixture +def pipeline_service_async_client(): + return PipelineServiceAsyncClient( + credentials=mock.MagicMock(), + ) + + +@pytest.fixture +def job_service_async_client(): + return JobServiceAsyncClient( + credentials=mock.MagicMock(), + ) + + +@pytest.fixture +def test_training_pipeline_name(pipeline_service_async_client): + return pipeline_service_async_client.training_pipeline_path( + project=TEST_PROJECT_ID, + location=TEST_REGION, + training_pipeline=TEST_PIPELINE_JOB_ID, + ) + + +@pytest.fixture +def test_custom_job_name(job_service_async_client): + return job_service_async_client.custom_job_path( + project=TEST_PROJECT_ID, + location=TEST_REGION, + custom_job=TEST_PIPELINE_JOB_ID, + ) + + class TestCustomJobWithDefaultProjectIdHook: def test_delegate_to_runtime_error(self): with pytest.raises(RuntimeError): @@ -462,3 +515,206 @@ def test_list_training_pipelines(self, mock_client) -> None: timeout=None, ) mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION) + + +class TestCustomJobAsyncHook: + @pytest.mark.asyncio + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_get_training_pipeline( + self, mock_pipeline_service_client, test_async_hook, test_training_pipeline_name + ): + mock_pipeline_service_client.return_value.training_pipeline_path = mock.MagicMock( + return_value=test_training_pipeline_name + ) + await test_async_hook.get_training_pipeline( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + pipeline_id=TEST_PIPELINE_JOB_ID, + ) + mock_pipeline_service_client.assert_awaited_once_with(region=TEST_REGION) + mock_pipeline_service_client.return_value.get_training_pipeline.assert_awaited_once_with( + request={"name": test_training_pipeline_name}, + retry=DEFAULT, + timeout=DEFAULT, + metadata=(), + ) + + @pytest.mark.asyncio + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_job_service_client")) + async def test_get_custom_job( + self, + mock_get_job_service_client, + test_async_hook, + test_custom_job_name, + ): + mock_get_job_service_client.return_value.custom_job_path = mock.MagicMock( + return_value=test_custom_job_name + ) + await test_async_hook.get_custom_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_PIPELINE_JOB_ID, + ) + mock_get_job_service_client.assert_awaited_once_with(region=TEST_REGION) + 
mock_get_job_service_client.return_value.get_custom_job.assert_awaited_once_with( + request={"name": test_custom_job_name}, + retry=DEFAULT, + timeout=DEFAULT, + metadata=(), + ) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_state_value", + [ + PipelineState.PIPELINE_STATE_CANCELLED, + PipelineState.PIPELINE_STATE_FAILED, + PipelineState.PIPELINE_STATE_PAUSED, + PipelineState.PIPELINE_STATE_SUCCEEDED, + ], + ) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_wait_for_training_pipeline_returns_pipeline_if_in_complete_state( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + pipeline_state_value, + test_async_hook, + test_training_pipeline_name, + ): + expected_obj = types.TrainingPipeline( + state=pipeline_state_value, + name=test_training_pipeline_name, + ) + mock_get_training_pipeline.return_value = expected_obj + actual_obj = await test_async_hook.wait_for_training_pipeline( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + pipeline_id=TEST_PIPELINE_JOB_ID, + ) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_REGION) + assert actual_obj == expected_obj + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "job_state_value", + [ + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_PAUSED, + JobState.JOB_STATE_SUCCEEDED, + ], + ) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_custom_job")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_job_service_client")) + async def test_wait_for_custom_job_returns_job_if_in_complete_state( + self, + mock_get_job_service_client, + mock_get_custom_job, + job_state_value, + test_async_hook, + test_custom_job_name, + ): + expected_obj = types.CustomJob( + state=job_state_value, + name=test_custom_job_name, + ) + mock_get_custom_job.return_value = expected_obj + actual_obj = await test_async_hook.wait_for_custom_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_PIPELINE_JOB_ID, + ) + mock_get_job_service_client.assert_awaited_once_with(region=TEST_REGION) + assert actual_obj == expected_obj + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "pipeline_state_value", + [ + PipelineState.PIPELINE_STATE_CANCELLING, + PipelineState.PIPELINE_STATE_PENDING, + PipelineState.PIPELINE_STATE_QUEUED, + PipelineState.PIPELINE_STATE_RUNNING, + PipelineState.PIPELINE_STATE_UNSPECIFIED, + ], + ) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_wait_for_training_pipeline_loop_is_still_running_if_in_incomplete_state( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + pipeline_state_value, + test_async_hook, + ): + mock_get_training_pipeline.return_value = types.TrainingPipeline(state=pipeline_state_value) + task = asyncio.create_task( + test_async_hook.wait_for_training_pipeline( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + pipeline_id=TEST_PIPELINE_JOB_ID, + ) + ) + await asyncio.sleep(0.5) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_REGION) + assert task.done() is False + task.cancel() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "job_state_value", + [ + JobState.JOB_STATE_CANCELLING, + JobState.JOB_STATE_PENDING, + JobState.JOB_STATE_QUEUED, + 
JobState.JOB_STATE_RUNNING, + JobState.JOB_STATE_UNSPECIFIED, + ], + ) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_custom_job")) + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_job_service_client")) + async def test_wait_for_custom_job_loop_is_still_running_if_in_incomplete_state( + self, + mock_get_job_service_client, + mock_get_custom_job, + job_state_value, + test_async_hook, + ): + mock_get_custom_job.return_value = types.CustomJob(state=job_state_value) + task = asyncio.create_task( + test_async_hook.wait_for_custom_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_PIPELINE_JOB_ID, + ) + ) + await asyncio.sleep(0.5) + mock_get_job_service_client.assert_awaited_once_with(region=TEST_REGION) + assert task.done() is False + task.cancel() + + @pytest.mark.asyncio + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_training_pipeline")) + async def test_wait_for_training_pipeline_raises_exception( + self, mock_get_training_pipeline, test_async_hook + ): + mock_get_training_pipeline.side_effect = mock.AsyncMock(side_effect=Exception()) + with pytest.raises(AirflowException): + await test_async_hook.wait_for_training_pipeline( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + pipeline_id=TEST_PIPELINE_JOB_ID, + ) + + @pytest.mark.asyncio + @mock.patch(CUSTOM_JOB_STRING.format("CustomJobAsyncHook.get_custom_job")) + async def test_wait_for_custom_job_raises_exception(self, mock_get_custom_job, test_async_hook): + mock_get_custom_job.side_effect = mock.AsyncMock(side_effect=Exception()) + with pytest.raises(AirflowException): + await test_async_hook.wait_for_custom_job( + project_id=TEST_PROJECT_ID, + location=TEST_REGION, + job_id=TEST_PIPELINE_JOB_ID, + ) diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py index 8ad24c6ccf346..bc74d79396d52 100644 --- a/tests/providers/google/cloud/operators/test_vertex_ai.py +++ b/tests/providers/google/cloud/operators/test_vertex_ai.py @@ -84,11 +84,18 @@ ListPipelineJobOperator, RunPipelineJobOperator, ) -from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger +from airflow.providers.google.cloud.triggers.vertex_ai import ( + CustomContainerTrainingJobTrigger, + CustomPythonPackageTrainingJobTrigger, + CustomTrainingJobTrigger, + RunPipelineJobTrigger, +) from airflow.utils import timezone VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}" VERTEX_AI_LINKS_PATH = "airflow.providers.google.cloud.links.vertex_ai.{}" +VERTEX_AI_TRIGGER_PATH = "airflow.providers.google.cloud.triggers.vertex_ai.{}" +VERTEX_AI_HOOK_PATH = "airflow.providers.google.cloud.hooks.vertex_ai." 
TIMEOUT = 120 RETRY = mock.MagicMock(Retry) METADATA = [("key", "value")] @@ -280,6 +287,96 @@ def test_execute(self, mock_hook, mock_dataset): model_version_description=None, ) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook")) + def test_execute_enters_deferred_state(self, mock_hook): + task = CreateCustomContainerTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(context={"ti": mock.MagicMock()}) + assert isinstance( + exc.value.trigger, CustomContainerTrainingJobTrigger + ), "Trigger is not a CustomContainerTrainingJobTrigger" + + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.xcom_push")) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomContainerTrainingJobOperator.hook")) + def test_execute_complete_success(self, mock_hook, mock_xcom_push): + task = CreateCustomContainerTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + expected_result = {} + mock_hook.return_value.exists.return_value = False + mock_xcom_push.return_value = None + actual_result = task.execute_complete( + context=None, event={"status": "success", "message": "", "job": {}} + ) + assert actual_result == expected_result + + def test_execute_complete_error_status_raises_exception(self): + task = CreateCustomContainerTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + args=ARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + command=COMMAND_2, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + with pytest.raises(AirflowException): + task.execute_complete(context=None, event={"status": "error", "message": "test message"}) + class TestVertexAICreateCustomPythonPackageTrainingJobOperator: 
@mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset")) @@ -372,6 +469,99 @@ def test_execute(self, mock_hook, mock_dataset): sync=True, ) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook")) + def test_execute_enters_deferred_state(self, mock_hook): + task = CreateCustomPythonPackageTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(context={"ti": mock.MagicMock()}) + assert isinstance( + exc.value.trigger, CustomPythonPackageTrainingJobTrigger + ), "Trigger is not a CustomPythonPackageTrainingJobTrigger" + + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.xcom_push")) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomPythonPackageTrainingJobOperator.hook")) + def test_execute_complete_success(self, mock_hook, mock_xcom_push): + task = CreateCustomPythonPackageTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + expected_result = {} + mock_hook.return_value.exists.return_value = False + mock_xcom_push.return_value = None + actual_result = task.execute_complete( + context=None, event={"status": "success", "message": "", "job": {}} + ) + assert actual_result == expected_result + + def test_execute_complete_error_status_raises_exception(self): + task = CreateCustomPythonPackageTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI, + python_module_name=PYTHON_MODULE_NAME, + container_uri=CONTAINER_URI, + args=ARGS, + model_serving_container_image_uri=CONTAINER_URI, + model_display_name=DISPLAY_NAME_2, + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, 
+ ) + with pytest.raises(AirflowException): + task.execute_complete(context=None, event={"status": "error", "message": "test message"}) + class TestVertexAICreateCustomTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.Dataset")) @@ -457,6 +647,78 @@ def test_execute(self, mock_hook, mock_dataset): model_version_description=None, ) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook")) + def test_execute_enters_deferred_state(self, mock_hook): + task = CreateCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + mock_hook.return_value.exists.return_value = False + with pytest.raises(TaskDeferred) as exc: + task.execute(context={"ti": mock.MagicMock()}) + assert isinstance( + exc.value.trigger, CustomTrainingJobTrigger + ), "Trigger is not a CustomTrainingJobTrigger" + + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.xcom_push")) + @mock.patch(VERTEX_AI_PATH.format("custom_job.CreateCustomTrainingJobOperator.hook")) + def test_execute_complete_success(self, mock_hook, mock_xcom_push): + task = CreateCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + expected_result = {} + mock_hook.return_value.exists.return_value = False + mock_xcom_push.return_value = None + actual_result = task.execute_complete( + context=None, event={"status": "success", "message": "", "job": {}} + ) + assert actual_result == expected_result + + def test_execute_complete_error_status_raises_exception(self): + task = CreateCustomTrainingJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + staging_bucket=STAGING_BUCKET, + display_name=DISPLAY_NAME, + script_path=PYTHON_PACKAGE, + args=PYTHON_PACKAGE_CMDARGS, + container_uri=CONTAINER_URI, + model_serving_container_image_uri=CONTAINER_URI, + requirements=[], + replica_count=1, + region=GCP_LOCATION, + project_id=GCP_PROJECT, + deferrable=True, + ) + with pytest.raises(AirflowException): + task.execute_complete(context=None, event={"status": "error", "message": "test message"}) + class TestVertexAIDeleteCustomTrainingJobOperator: @mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook")) diff --git a/tests/providers/google/cloud/triggers/test_vertex_ai.py b/tests/providers/google/cloud/triggers/test_vertex_ai.py index 6a5ed10acfec2..ad779d7a60c04 100644 --- a/tests/providers/google/cloud/triggers/test_vertex_ai.py +++ b/tests/providers/google/cloud/triggers/test_vertex_ai.py @@ -30,11 +30,15 @@ ) from airflow.exceptions import AirflowException +from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobAsyncHook from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobAsyncHook from airflow.providers.google.cloud.triggers.vertex_ai import ( 
BaseVertexAIJobTrigger, CreateBatchPredictionJobTrigger, CreateHyperparameterTuningJobTrigger, + CustomContainerTrainingJobTrigger, + CustomPythonPackageTrainingJobTrigger, + CustomTrainingJobTrigger, RunPipelineJobTrigger, ) from airflow.triggers.base import TriggerEvent @@ -67,6 +71,50 @@ def run_pipeline_job_trigger(): ) +@pytest.fixture +def custom_container_training_job_trigger(): + return CustomContainerTrainingJobTrigger( + conn_id=TEST_CONN_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + +@pytest.fixture +def custom_python_package_training_job_trigger(): + return CustomPythonPackageTrainingJobTrigger( + conn_id=TEST_CONN_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + +@pytest.fixture +def custom_training_job_trigger(): + return CustomTrainingJobTrigger( + conn_id=TEST_CONN_ID, + project_id=TEST_PROJECT_ID, + location=TEST_LOCATION, + job_id=TEST_HPT_JOB_ID, + poll_interval=TEST_POLL_INTERVAL, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + +@pytest.fixture +def custom_job_async_hook(): + return CustomJobAsyncHook( + gcp_conn_id=TEST_CONN_ID, + impersonation_chain=TEST_IMPERSONATION_CHAIN, + ) + + @pytest.fixture def pipeline_job_async_hook(): return PipelineJobAsyncHook( @@ -91,6 +139,15 @@ def test_pipeline_job_name(pipeline_service_async_client): ) +@pytest.fixture +def test_training_pipeline_name(pipeline_service_async_client): + return pipeline_service_async_client.training_pipeline_path( + project=TEST_PROJECT_ID, + location=TEST_LOCATION, + training_pipeline=TEST_HPT_JOB_ID, + ) + + class TestBaseVertexAIJobTrigger: def setup_method(self, method): self.trigger = BaseVertexAIJobTrigger( @@ -455,9 +512,6 @@ async def test_run_test_run_loop_is_still_running_if_pipeline_is_running( @pytest.mark.asyncio @mock.patch(VERTEX_AI_TRIGGER_PATH.format("PipelineJobAsyncHook.get_pipeline_job")) async def test_run_raises_exception(self, mock_get_pipeline_job, run_pipeline_job_trigger): - """ - Tests the DataflowJobAutoScalingEventTrigger does fire if there is an exception. 
-        """
         mock_get_pipeline_job.side_effect = mock.AsyncMock(side_effect=Exception("Test exception"))
         expected_event = TriggerEvent(
             {
@@ -478,3 +532,463 @@ async def test_wait_job(self, mock_wait_for_pipeline_job, run_pipeline_job_trigg
             job_id=run_pipeline_job_trigger.job_id,
             poll_interval=run_pipeline_job_trigger.poll_interval,
         )
+
+
+class TestCustomTrainingJobTrigger:
+    def test_serialize(self, custom_training_job_trigger):
+        actual_data = custom_training_job_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.vertex_ai.CustomTrainingJobTrigger",
+            {
+                "conn_id": TEST_CONN_ID,
+                "project_id": TEST_PROJECT_ID,
+                "location": TEST_LOCATION,
+                "job_id": TEST_HPT_JOB_ID,
+                "poll_interval": TEST_POLL_INTERVAL,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+            },
+        )
+        assert actual_data == expected_data
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_SUCCEEDED,
+            PipelineState.PIPELINE_STATE_PAUSED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict")
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_yields_success_event_on_successful_pipeline_state(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        mock_pipeline_job_dict,
+        custom_training_job_trigger,
+        pipeline_state_value,
+        test_training_pipeline_name,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(
+            state=pipeline_state_value,
+            name=test_training_pipeline_name,
+        )
+        mock_pipeline_job_dict.return_value = {}
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": (
+                    f"{custom_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} "
+                    f"completed with status {pipeline_state_value.name}"
+                ),
+                "job": {},
+            }
+        )
+        actual_event = await custom_training_job_trigger.run().asend(None)
+        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_FAILED,
+            PipelineState.PIPELINE_STATE_CANCELLED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict")
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_yields_error_event_on_failed_pipeline_state(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        mock_pipeline_job_dict,
+        pipeline_state_value,
+        custom_training_job_trigger,
+        test_training_pipeline_name,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(
+            state=pipeline_state_value,
+            name=test_training_pipeline_name,
+        )
+        mock_pipeline_job_dict.return_value = {}
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": (
+                    f"{custom_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} "
+                    f"completed with status {pipeline_state_value.name}"
+                ),
+                "job": {},
+            }
+        )
+        actual_event = await custom_training_job_trigger.run().asend(None)
+        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_CANCELLING,
+            PipelineState.PIPELINE_STATE_PENDING,
+            PipelineState.PIPELINE_STATE_QUEUED,
+            PipelineState.PIPELINE_STATE_RUNNING,
+            PipelineState.PIPELINE_STATE_UNSPECIFIED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_test_run_loop_is_still_running_if_pipeline_is_running(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        pipeline_state_value,
+        custom_training_job_trigger,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(state=pipeline_state_value)
+        task = asyncio.create_task(custom_training_job_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert task.done() is False
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_raises_exception(
+        self, mock_get_pipeline_service_client, mock_get_training_pipeline, custom_training_job_trigger
+    ):
+        mock_get_training_pipeline.side_effect = mock.AsyncMock(side_effect=Exception("Test exception"))
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": "Test exception",
+            }
+        )
+        actual_event = await custom_training_job_trigger.run().asend(None)
+        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.wait_for_training_pipeline"))
+    async def test_wait_training_pipeline(self, mock_wait_for_training_pipeline, custom_training_job_trigger):
+        await custom_training_job_trigger._wait_job()
+        mock_wait_for_training_pipeline.assert_awaited_once_with(
+            project_id=custom_training_job_trigger.project_id,
+            location=custom_training_job_trigger.location,
+            pipeline_id=custom_training_job_trigger.job_id,
+            poll_interval=custom_training_job_trigger.poll_interval,
+        )
+
+
+class TestCustomContainerTrainingJobTrigger:
+    def test_serialize(self, custom_container_training_job_trigger):
+        actual_data = custom_container_training_job_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.vertex_ai.CustomContainerTrainingJobTrigger",
+            {
+                "conn_id": TEST_CONN_ID,
+                "project_id": TEST_PROJECT_ID,
+                "location": TEST_LOCATION,
+                "job_id": TEST_HPT_JOB_ID,
+                "poll_interval": TEST_POLL_INTERVAL,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+            },
+        )
+        assert actual_data == expected_data
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_SUCCEEDED,
+            PipelineState.PIPELINE_STATE_PAUSED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict")
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_yields_success_event_on_successful_pipeline_state(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        mock_pipeline_job_dict,
+        custom_container_training_job_trigger,
+        pipeline_state_value,
+        test_training_pipeline_name,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(
+            state=pipeline_state_value,
+
name=test_training_pipeline_name, + ) + mock_pipeline_job_dict.return_value = {} + expected_event = TriggerEvent( + { + "status": "success", + "message": ( + f"{custom_container_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} " + f"completed with status {pipeline_state_value.name}" + ), + "job": {}, + } + ) + actual_event = await custom_container_training_job_trigger.run().asend(None) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert actual_event == expected_event + + @pytest.mark.parametrize( + "pipeline_state_value", + [ + PipelineState.PIPELINE_STATE_FAILED, + PipelineState.PIPELINE_STATE_CANCELLED, + ], + ) + @pytest.mark.asyncio + @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict") + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_run_yields_error_event_on_failed_pipeline_state( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + mock_pipeline_job_dict, + pipeline_state_value, + custom_container_training_job_trigger, + test_training_pipeline_name, + ): + mock_get_training_pipeline.return_value = types.TrainingPipeline( + state=pipeline_state_value, + name=test_training_pipeline_name, + ) + mock_pipeline_job_dict.return_value = {} + expected_event = TriggerEvent( + { + "status": "error", + "message": ( + f"{custom_container_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} " + f"completed with status {pipeline_state_value.name}" + ), + "job": {}, + } + ) + actual_event = await custom_container_training_job_trigger.run().asend(None) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert actual_event == expected_event + + @pytest.mark.parametrize( + "pipeline_state_value", + [ + PipelineState.PIPELINE_STATE_CANCELLING, + PipelineState.PIPELINE_STATE_PENDING, + PipelineState.PIPELINE_STATE_QUEUED, + PipelineState.PIPELINE_STATE_RUNNING, + PipelineState.PIPELINE_STATE_UNSPECIFIED, + ], + ) + @pytest.mark.asyncio + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_run_test_run_loop_is_still_running_if_pipeline_is_running( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + pipeline_state_value, + custom_container_training_job_trigger, + ): + mock_get_training_pipeline.return_value = types.TrainingPipeline(state=pipeline_state_value) + task = asyncio.create_task(custom_container_training_job_trigger.run().__anext__()) + await asyncio.sleep(0.5) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert task.done() is False + task.cancel() + + @pytest.mark.asyncio + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_run_raises_exception( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + custom_container_training_job_trigger, + ): + mock_get_training_pipeline.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent( + { + "status": "error", + "message": "Test exception", + } + ) + actual_event = await custom_container_training_job_trigger.run().asend(None) + 
        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert expected_event == actual_event
+
+    @pytest.mark.asyncio
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.wait_for_training_pipeline"))
+    async def test_wait_training_pipeline(
+        self, mock_wait_for_training_pipeline, custom_container_training_job_trigger
+    ):
+        await custom_container_training_job_trigger._wait_job()
+        mock_wait_for_training_pipeline.assert_awaited_once_with(
+            project_id=custom_container_training_job_trigger.project_id,
+            location=custom_container_training_job_trigger.location,
+            pipeline_id=custom_container_training_job_trigger.job_id,
+            poll_interval=custom_container_training_job_trigger.poll_interval,
+        )
+
+
+class TestCustomPythonPackageTrainingJobTrigger:
+    def test_serialize(self, custom_python_package_training_job_trigger):
+        actual_data = custom_python_package_training_job_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.vertex_ai.CustomPythonPackageTrainingJobTrigger",
+            {
+                "conn_id": TEST_CONN_ID,
+                "project_id": TEST_PROJECT_ID,
+                "location": TEST_LOCATION,
+                "job_id": TEST_HPT_JOB_ID,
+                "poll_interval": TEST_POLL_INTERVAL,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+            },
+        )
+        assert actual_data == expected_data
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_SUCCEEDED,
+            PipelineState.PIPELINE_STATE_PAUSED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict")
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_yields_success_event_on_successful_pipeline_state(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        mock_pipeline_job_dict,
+        custom_python_package_training_job_trigger,
+        pipeline_state_value,
+        test_training_pipeline_name,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(
+            state=pipeline_state_value,
+            name=test_training_pipeline_name,
+        )
+        mock_pipeline_job_dict.return_value = {}
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": (
+                    f"{custom_python_package_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} "
+                    f"completed with status {pipeline_state_value.name}"
+                ),
+                "job": {},
+            }
+        )
+        actual_event = await custom_python_package_training_job_trigger.run().asend(None)
+        mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_FAILED,
+            PipelineState.PIPELINE_STATE_CANCELLED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.TrainingPipeline.to_dict")
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client"))
+    async def test_run_yields_error_event_on_failed_pipeline_state(
+        self,
+        mock_get_pipeline_service_client,
+        mock_get_training_pipeline,
+        mock_pipeline_job_dict,
+        pipeline_state_value,
+        custom_python_package_training_job_trigger,
+        test_training_pipeline_name,
+    ):
+        mock_get_training_pipeline.return_value = types.TrainingPipeline(
+            state=pipeline_state_value,
+            name=test_training_pipeline_name,
+        )
+        mock_pipeline_job_dict.return_value = {}
+
expected_event = TriggerEvent( + { + "status": "error", + "message": ( + f"{custom_python_package_training_job_trigger.job_type_verbose_name} {test_training_pipeline_name} " + f"completed with status {pipeline_state_value.name}" + ), + "job": {}, + } + ) + actual_event = await custom_python_package_training_job_trigger.run().asend(None) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert actual_event == expected_event + + @pytest.mark.parametrize( + "pipeline_state_value", + [ + PipelineState.PIPELINE_STATE_CANCELLING, + PipelineState.PIPELINE_STATE_PENDING, + PipelineState.PIPELINE_STATE_QUEUED, + PipelineState.PIPELINE_STATE_RUNNING, + PipelineState.PIPELINE_STATE_UNSPECIFIED, + ], + ) + @pytest.mark.asyncio + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_run_test_run_loop_is_still_running_if_pipeline_is_running( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + pipeline_state_value, + custom_python_package_training_job_trigger, + ): + mock_get_training_pipeline.return_value = types.TrainingPipeline(state=pipeline_state_value) + task = asyncio.create_task(custom_python_package_training_job_trigger.run().__anext__()) + await asyncio.sleep(0.5) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert task.done() is False + task.cancel() + + @pytest.mark.asyncio + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_training_pipeline")) + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.get_pipeline_service_client")) + async def test_run_raises_exception( + self, + mock_get_pipeline_service_client, + mock_get_training_pipeline, + custom_python_package_training_job_trigger, + ): + mock_get_training_pipeline.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + expected_event = TriggerEvent( + { + "status": "error", + "message": "Test exception", + } + ) + actual_event = await custom_python_package_training_job_trigger.run().asend(None) + mock_get_pipeline_service_client.assert_awaited_once_with(region=TEST_LOCATION) + assert expected_event == actual_event + + @pytest.mark.asyncio + @mock.patch(VERTEX_AI_TRIGGER_PATH.format("CustomJobAsyncHook.wait_for_training_pipeline")) + async def test_wait_training_pipeline( + self, mock_wait_for_training_pipeline, custom_python_package_training_job_trigger + ): + await custom_python_package_training_job_trigger._wait_job() + mock_wait_for_training_pipeline.assert_awaited_once_with( + project_id=custom_python_package_training_job_trigger.project_id, + location=custom_python_package_training_job_trigger.location, + pipeline_id=custom_python_package_training_job_trigger.job_id, + poll_interval=custom_python_package_training_job_trigger.poll_interval, + ) diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py index 72ae090e6e708..df1a3e5306191 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_container.py @@ -136,6 +136,30 @@ def TABULAR_DATASET(bucket_name): ) # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator] + # [START 
how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable] + create_custom_container_training_job_deferrable = CreateCustomContainerTrainingJobOperator( + task_id="custom_container_task_deferrable", + staging_bucket=f"gs://{CUSTOM_CONTAINER_GCS_BUCKET_NAME}", + display_name=f"{CONTAINER_DISPLAY_NAME}_DEF", + container_uri=CUSTOM_CONTAINER_URI, + model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=tabular_dataset_id, + command=["python3", "task.py"], + model_display_name=f"{MODEL_DISPLAY_NAME}_DEF", + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_create_custom_container_training_job_operator_deferrable] + delete_custom_training_job = DeleteCustomTrainingJobOperator( task_id="delete_custom_training_job", training_pipeline_id="{{ task_instance.xcom_pull(task_ids='custom_container_task', " @@ -147,6 +171,17 @@ def TABULAR_DATASET(bucket_name): trigger_rule=TriggerRule.ALL_DONE, ) + delete_custom_training_job_deferrable = DeleteCustomTrainingJobOperator( + task_id="delete_custom_training_job_deferrable", + training_pipeline_id="{{ task_instance.xcom_pull(task_ids='custom_container_task_deferrable', " + "key='training_id') }}", + custom_job_id="{{ task_instance.xcom_pull(task_ids='custom_container_task_deferrable', " + "key='custom_job_id') }}", + region=REGION, + project_id=PROJECT_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + delete_tabular_dataset = DeleteDatasetOperator( task_id="delete_tabular_dataset", dataset_id=tabular_dataset_id, @@ -166,9 +201,10 @@ def TABULAR_DATASET(bucket_name): >> move_data_files >> create_tabular_dataset # TEST BODY - >> create_custom_container_training_job + >> [create_custom_container_training_job, create_custom_container_training_job_deferrable] # TEST TEARDOWN >> delete_custom_training_job + >> delete_custom_training_job_deferrable >> delete_tabular_dataset >> delete_bucket ) diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py index fa876f7383688..702e6a6c51f16 100644 --- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py +++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job.py @@ -30,6 +30,7 @@ from google.protobuf.json_format import ParseDict from google.protobuf.struct_pb2 import Value +from airflow.models.baseoperator import chain from airflow.models.dag import DAG from airflow.providers.google.cloud.operators.gcs import ( GCSCreateBucketOperator, @@ -129,13 +130,32 @@ def TABULAR_DATASET(bucket_name): dataset_id=tabular_dataset_id, replica_count=REPLICA_COUNT, model_display_name=MODEL_DISPLAY_NAME, - sync=False, region=REGION, project_id=PROJECT_ID, ) model_id_v1 = create_custom_training_job.output["model_id"] # [END how_to_cloud_vertex_ai_create_custom_training_job_operator] + # [START how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable] + create_custom_training_job_deferrable = CreateCustomTrainingJobOperator( + task_id="custom_task_deferrable", + staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}", + display_name=f"{CUSTOM_DISPLAY_NAME}_DEF", + 
        script_path=LOCAL_TRAINING_SCRIPT_PATH,
+        container_uri=CONTAINER_URI,
+        requirements=["gcsfs==0.7.1"],
+        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
+        # run params
+        dataset_id=tabular_dataset_id,
+        replica_count=REPLICA_COUNT,
+        model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+        region=REGION,
+        project_id=PROJECT_ID,
+        deferrable=True,
+    )
+    model_id_v1_deferrable = create_custom_training_job_deferrable.output["model_id"]
+    # [END how_to_cloud_vertex_ai_create_custom_training_job_operator_deferrable]
+
     # [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
     create_custom_training_job_v2 = CreateCustomTrainingJobOperator(
         task_id="custom_task_v2",
@@ -156,6 +176,27 @@ def TABULAR_DATASET(bucket_name):
     )
     # [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator]
 
+    # [START how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+    create_custom_training_job_v2_deferrable = CreateCustomTrainingJobOperator(
+        task_id="custom_task_v2_deferrable",
+        staging_bucket=f"gs://{CUSTOM_GCS_BUCKET_NAME}",
+        display_name=f"{CUSTOM_DISPLAY_NAME}_DEF",
+        script_path=LOCAL_TRAINING_SCRIPT_PATH,
+        container_uri=CONTAINER_URI,
+        requirements=["gcsfs==0.7.1"],
+        model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI,
+        parent_model=model_id_v1_deferrable,
+        # run params
+        dataset_id=tabular_dataset_id,
+        replica_count=REPLICA_COUNT,
+        model_display_name=f"{MODEL_DISPLAY_NAME}_DEF",
+        sync=False,
+        region=REGION,
+        project_id=PROJECT_ID,
+        deferrable=True,
+    )
+    # [END how_to_cloud_vertex_ai_create_custom_training_job_v2_operator_deferrable]
+
     # [START how_to_cloud_vertex_ai_delete_custom_training_job_operator]
     delete_custom_training_job = DeleteCustomTrainingJobOperator(
         task_id="delete_custom_training_job",
@@ -181,18 +222,20 @@ def TABULAR_DATASET(bucket_name):
     )
 
     (
-        # TEST SETUP
-        create_bucket
-        >> move_data_files
-        >> download_training_script_file
-        >> create_tabular_dataset
-        # TEST BODY
-        >> create_custom_training_job
-        >> create_custom_training_job_v2
-        # TEST TEARDOWN
-        >> delete_custom_training_job
-        >> delete_tabular_dataset
-        >> delete_bucket
+        chain(
+            # TEST SETUP
+            create_bucket,
+            move_data_files,
+            download_training_script_file,
+            create_tabular_dataset,
+            # TEST BODY
+            [create_custom_training_job, create_custom_training_job_deferrable],
+            [create_custom_training_job_v2, create_custom_training_job_v2_deferrable],
+            # TEST TEARDOWN
+            delete_custom_training_job,
+            delete_tabular_dataset,
+            delete_bucket,
+        )
+    )
diff --git a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py
index 867154a03d4c0..5aad847f84dd9 100644
--- a/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py
+++ b/tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_custom_job_python_package.py
@@ -139,6 +139,31 @@ def TABULAR_DATASET(bucket_name):
     )
     # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator]
 
+    # [START how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable]
+    create_custom_python_package_training_job_deferrable = CreateCustomPythonPackageTrainingJobOperator(
+        task_id="python_package_task_deferrable",
+        staging_bucket=f"gs://{CUSTOM_PYTHON_GCS_BUCKET_NAME}",
+        display_name=f"{PACKAGE_DISPLAY_NAME}_DEF",
+        python_package_gcs_uri=PYTHON_PACKAGE_GCS_URI,
+        python_module_name=PYTHON_MODULE_NAME,
+        container_uri=CONTAINER_URI,
+
model_serving_container_image_uri=MODEL_SERVING_CONTAINER_URI, + # run params + dataset_id=tabular_dataset_id, + model_display_name=f"{MODEL_DISPLAY_NAME}_DEF", + replica_count=REPLICA_COUNT, + machine_type=MACHINE_TYPE, + accelerator_type=ACCELERATOR_TYPE, + accelerator_count=ACCELERATOR_COUNT, + training_fraction_split=TRAINING_FRACTION_SPLIT, + validation_fraction_split=VALIDATION_FRACTION_SPLIT, + test_fraction_split=TEST_FRACTION_SPLIT, + region=REGION, + project_id=PROJECT_ID, + deferrable=True, + ) + # [END how_to_cloud_vertex_ai_create_custom_python_package_training_job_operator_deferrable] + delete_custom_training_job = DeleteCustomTrainingJobOperator( task_id="delete_custom_training_job", training_pipeline_id="{{ task_instance.xcom_pull(task_ids='python_package_task', " @@ -149,6 +174,16 @@ def TABULAR_DATASET(bucket_name): trigger_rule=TriggerRule.ALL_DONE, ) + delete_custom_training_job_deferrable = DeleteCustomTrainingJobOperator( + task_id="delete_custom_training_job_deferrable", + training_pipeline_id="{{ task_instance.xcom_pull(task_ids='python_package_task_deferrable', " + "key='training_id') }}", + custom_job_id="{{ task_instance.xcom_pull(task_ids='python_package_task_deferrable', key='custom_job_id') }}", + region=REGION, + project_id=PROJECT_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + delete_tabular_dataset = DeleteDatasetOperator( task_id="delete_tabular_dataset", dataset_id=tabular_dataset_id, @@ -168,9 +203,10 @@ def TABULAR_DATASET(bucket_name): >> move_data_files >> create_tabular_dataset # TEST BODY - >> create_custom_python_package_training_job + >> [create_custom_python_package_training_job, create_custom_python_package_training_job_deferrable] # TEST TEARDOWN >> delete_custom_training_job + >> delete_custom_training_job_deferrable >> delete_tabular_dataset >> delete_bucket )