diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py
index cd4b49b961cf7..f7148b9336d06 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/hyperparameter_tuning_job.py
@@ -39,8 +39,7 @@ if TYPE_CHECKING:
     from google.api_core.operation import Operation
-    from google.api_core.retry import Retry
-    from google.api_core.retry_async import AsyncRetry
+    from google.api_core.retry import AsyncRetry, Retry
     from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager
diff --git a/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py b/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
index c1be11b6820bb..de3434da59400 100644
--- a/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
+++ b/airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
@@ -28,15 +28,21 @@
 from google.api_core.client_options import ClientOptions
 from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
 from google.cloud.aiplatform import PipelineJob
-from google.cloud.aiplatform_v1 import PipelineServiceClient
+from google.cloud.aiplatform_v1 import (
+    GetPipelineJobRequest,
+    PipelineServiceAsyncClient,
+    PipelineServiceClient,
+    types,
+)
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
     from google.api_core.operation import Operation
-    from google.api_core.retry import Retry
+    from google.api_core.retry import AsyncRetry, Retry
+    from google.auth.credentials import Credentials
     from google.cloud.aiplatform.metadata import experiment_resources
     from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListPipelineJobsPager
 
@@ -101,11 +107,6 @@ def get_pipeline_job_object(
             failure_policy=failure_policy,
         )
 
-    @staticmethod
-    def extract_pipeline_job_id(obj: dict) -> str:
-        """Return unique id of the pipeline_job."""
-        return obj["name"].rpartition("/")[-1]
-
     def wait_for_operation(self, operation: Operation, timeout: float | None = None):
         """Wait for long-lasting operation to complete."""
         try:
@@ -129,7 +130,7 @@ def create_pipeline_job(
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
+    ) -> types.PipelineJob:
         """
         Create a PipelineJob. A PipelineJob will run immediately when created.
 
@@ -182,7 +183,7 @@ def run_pipeline_job(
         # END: run param
     ) -> PipelineJob:
         """
-        Run PipelineJob and monitor the job until completion.
+        Create and run a PipelineJob until its completion.
 
         :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
         :param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -243,7 +244,103 @@ def run_pipeline_job(
             location=region,
             failure_policy=failure_policy,
         )
+        self._pipeline_job.submit(
+            service_account=service_account,
+            network=network,
+            create_request_timeout=create_request_timeout,
+            experiment=experiment,
+        )
+        self._pipeline_job.wait()
+
+        return self._pipeline_job
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def submit_pipeline_job(
+        self,
+        project_id: str,
+        region: str,
+        display_name: str,
+        template_path: str,
+        job_id: str | None = None,
+        pipeline_root: str | None = None,
+        parameter_values: dict[str, Any] | None = None,
+        input_artifacts: dict[str, str] | None = None,
+        enable_caching: bool | None = None,
+        encryption_spec_key_name: str | None = None,
+        labels: dict[str, str] | None = None,
+        failure_policy: str | None = None,
+        # START: run param
+        service_account: str | None = None,
+        network: str | None = None,
+        create_request_timeout: float | None = None,
+        experiment: str | experiment_resources.Experiment | None = None,
+        # END: run param
+    ) -> PipelineJob:
+        """
+        Create and start a PipelineJob run.
+
+        For more info about the client method please see:
+        https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob#google_cloud_aiplatform_PipelineJob_submit
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param region: Required. The ID of the Google Cloud region that the service belongs to.
+        :param display_name: Required. The user-defined name of this Pipeline.
+        :param template_path: Required. The path of PipelineJob or PipelineSpec JSON or YAML file. It can be
+            a local path, a Google Cloud Storage URI (e.g. "gs://project.name"), an Artifact Registry URI
+            (e.g. "https://us-central1-kfp.pkg.dev/proj/repo/pack/latest"), or an HTTPS URI.
+        :param job_id: Optional. The unique ID of the job run. If not specified, pipeline name + timestamp
+            will be used.
+        :param pipeline_root: Optional. The root of the pipeline outputs. If not set, the staging bucket set
+            in aiplatform.init will be used. If that's not set, a pipeline-specific artifacts bucket will be
+            used.
+        :param parameter_values: Optional. The mapping from runtime parameter names to their values that
+            control the pipeline run.
+        :param input_artifacts: Optional. The mapping from the runtime parameter name for this artifact to
+            its resource id. For example: "vertex_model":"456". Note: full resource name
+            ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
+        :param enable_caching: Optional. Whether to turn on caching for the run.
+            If this is not set, defaults to the compile time settings, which are True for all tasks by
+            default, while users may specify different caching options for individual tasks.
+            If this is set, the setting applies to all tasks in the pipeline. Overrides the compile time
+            settings.
+        :param encryption_spec_key_name: Optional. The Cloud KMS resource identifier of the customer managed
+            encryption key used to protect the job. Has the form:
+            ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+            The key needs to be in the same region as where the compute resource is created. If this is set,
+            then all resources created by the PipelineJob will be encrypted with the provided encryption key.
+            Overrides encryption_spec_key_name set in aiplatform.init.
+        :param labels: Optional. The user defined metadata to organize PipelineJob.
+        :param failure_policy: Optional. The failure policy - "slow" or "fast". Currently, the default of a
+            pipeline is that the pipeline will continue to run until no more tasks can be executed, also
+            known as PIPELINE_FAILURE_POLICY_FAIL_SLOW (corresponds to "slow"). However, if a pipeline is set
+            to PIPELINE_FAILURE_POLICY_FAIL_FAST (corresponds to "fast"), it will stop scheduling any new
+            tasks when a task has failed. Any scheduled tasks will continue to completion.
+        :param service_account: Optional. Specifies the service account for workload run-as account. Users
+            submitting jobs must have act-as permission on this run-as account.
+        :param network: Optional. The full name of the Compute Engine network to which the job should be
+            peered. For example, projects/12345/global/networks/myVPC.
+            Private services access must already be configured for the network. If left unspecified, the
+            network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
+        :param create_request_timeout: Optional. The timeout for the create request in seconds.
+        :param experiment: Optional. The Vertex AI experiment name or instance to associate to this PipelineJob.
+            Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as metrics
+            to the current Experiment Run. Pipeline parameters will be associated as parameters to
+            the current Experiment Run.
+        """
+        self._pipeline_job = self.get_pipeline_job_object(
+            display_name=display_name,
+            template_path=template_path,
+            job_id=job_id,
+            pipeline_root=pipeline_root,
+            parameter_values=parameter_values,
+            input_artifacts=input_artifacts,
+            enable_caching=enable_caching,
+            encryption_spec_key_name=encryption_spec_key_name,
+            labels=labels,
+            project=project_id,
+            location=region,
+            failure_policy=failure_policy,
+        )
         self._pipeline_job.submit(
             service_account=service_account,
             network=network,
@@ -251,7 +348,6 @@ def run_pipeline_job(
             experiment=experiment,
         )
-        self._pipeline_job.wait()
 
         return self._pipeline_job
 
     @GoogleBaseHook.fallback_to_default_project_id
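The behavioral difference introduced above is that `run_pipeline_job` still blocks on `wait()`, while the new `submit_pipeline_job` returns as soon as the job has been submitted. A minimal sketch of the two call styles (illustrative only, not part of this change; the connection id, project, region, and template path are placeholder values):

```python
from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook

hook = PipelineJobHook(gcp_conn_id="google_cloud_default")  # placeholder connection id

common_args = dict(
    project_id="my-project",  # placeholder
    region="us-central1",  # placeholder
    display_name="example-pipeline",
    template_path="gs://my-bucket/pipeline.yaml",  # placeholder
)

# Blocks until the pipeline job reaches a terminal state.
finished_job = hook.run_pipeline_job(**common_args)

# Returns right after submission; completion is checked elsewhere,
# e.g. by RunPipelineJobTrigger when the operator runs in deferrable mode.
submitted_job = hook.submit_pipeline_job(**common_args)
print(submitted_job.job_id, submitted_job.state)
```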
@@ -263,7 +359,7 @@ def get_pipeline_job(
         retry: Retry | _MethodDefault = DEFAULT,
         timeout: float | None = None,
         metadata: Sequence[tuple[str, str]] = (),
-    ) -> PipelineJob:
+    ) -> types.PipelineJob:
         """
         Get a PipelineJob.
 
@@ -407,3 +503,86 @@ def delete_pipeline_job(
             metadata=metadata,
         )
         return result
+
+    @staticmethod
+    def extract_pipeline_job_id(obj: dict) -> str:
+        """Return unique id of a pipeline job from its name."""
+        return obj["name"].rpartition("/")[-1]
+
+
+class PipelineJobAsyncHook(GoogleBaseAsyncHook):
+    """Asynchronous hook for Google Cloud Vertex AI Pipeline Job APIs."""
+
+    sync_hook_class = PipelineJobHook
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            gcp_conn_id=gcp_conn_id,
+            impersonation_chain=impersonation_chain,
+        )
+
+    async def get_credentials(self) -> Credentials:
+        credentials = (await self.get_sync_hook()).get_credentials()
+        return credentials
+
+    async def get_project_id(self) -> str:
+        sync_hook = await self.get_sync_hook()
+        return sync_hook.project_id
+
+    async def get_project_location(self) -> str:
+        sync_hook = await self.get_sync_hook()
+        return sync_hook.location
+
+    async def get_pipeline_service_client(
+        self,
+        region: str | None = None,
+    ) -> PipelineServiceAsyncClient:
+        """Return PipelineServiceAsyncClient object."""
+        if region and region != "global":
+            client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
+        else:
+            client_options = ClientOptions()
+        return PipelineServiceAsyncClient(
+            credentials=await self.get_credentials(),
+            client_info=CLIENT_INFO,
+            client_options=client_options,
+        )
+
+    async def get_pipeline_job(
+        self,
+        project_id: str,
+        location: str,
+        job_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | _MethodDefault | None = DEFAULT,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> types.PipelineJob:
+        """
+        Get a PipelineJob message from PipelineServiceAsyncClient.
+
+        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+        :param location: Required. The ID of the Google Cloud region that the service belongs to.
+        :param job_id: Required. The ID of the PipelineJob resource.
+        :param retry: Designation of what errors, if any, should be retried.
+        :param timeout: The timeout for this request.
+        :param metadata: Strings which should be sent along with the request as metadata.
+        """
+        client = await self.get_pipeline_service_client(region=location)
+        request = self.get_pipeline_job_request(project=project_id, location=location, job=job_id)
+        response: types.PipelineJob = await client.get_pipeline_job(
+            request=request,
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return response
+
+    @staticmethod
+    def get_pipeline_job_request(project: str, location: str, job: str) -> GetPipelineJobRequest:
+        name = f"projects/{project}/locations/{location}/pipelineJobs/{job}"
+        return GetPipelineJobRequest(name=name)
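Similarly, a sketch of how the new `PipelineJobAsyncHook.get_pipeline_job` might be polled outside of a trigger (illustrative only; project, region, and job id are placeholders):

```python
import asyncio

from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobAsyncHook


async def get_state(job_id: str) -> str:
    """Fetch the pipeline job once through the async client and return its state name."""
    hook = PipelineJobAsyncHook(gcp_conn_id="google_cloud_default")  # placeholder connection id
    job = await hook.get_pipeline_job(
        project_id="my-project",  # placeholder
        location="us-central1",  # placeholder
        job_id=job_id,
    )
    return job.state.name


if __name__ == "__main__":
    print(asyncio.run(get_state("example-pipeline-job")))  # placeholder job id
```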
+ """ + client = await self.get_pipeline_service_client(region=location) + request = self.get_pipeline_job_request(project=project_id, location=location, job=job_id) + response: types.PipelineJob = await client.get_pipeline_job( + request=request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + return response + + @staticmethod + def get_pipeline_job_request(project: str, location: str, job: str) -> GetPipelineJobRequest: + name = f"projects/{project}/locations/{location}/pipelineJobs/{job}" + return GetPipelineJobRequest(name=name) diff --git a/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py index 1510f8fcfac19..e6d04af778983 100644 --- a/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py +++ b/airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py @@ -18,21 +18,26 @@ """This module contains Google Vertex AI operators.""" from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.aiplatform_v1.types import PipelineJob +from google.cloud.aiplatform_v1 import types +from airflow.configuration import conf +from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook from airflow.providers.google.cloud.links.vertex_ai import ( VertexAIPipelineJobLink, VertexAIPipelineJobListLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator +from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger if TYPE_CHECKING: from google.api_core.retry import Retry + from google.cloud.aiplatform import PipelineJob from google.cloud.aiplatform.metadata import experiment_resources from airflow.utils.context import Context @@ -40,7 +45,7 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator): """ - Run Pipeline job. + Create and run a Pipeline job. :param project_id: Required. The ID of the Google Cloud project that the service belongs to. :param region: Required. The ID of the Google Cloud region that the service belongs to. @@ -82,11 +87,13 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator): Private services access must already be configured for the network. If left unspecified, the network set in aiplatform.init will be used. Otherwise, the job is not peered with any network. :param create_request_timeout: Optional. The timeout for the create request in seconds. - :param experiment: Optional. The Vertex AI experiment name or instance to associate to this - PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as - metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to + :param experiment: Optional. The Vertex AI experiment name or instance to associate to this PipelineJob. + Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as metrics + to the current Experiment Run. Pipeline parameters will be associated as parameters to the current Experiment Run. :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param sync: Whether to execute this method synchronously. If False, this method will unblock, and it + will be executed in a concurrent Future. The default is True. 
     :param impersonation_chain: Optional service account to impersonate using short-term
         credentials, or chained list of accounts required to get the access_token
         of the last account in the list, which will be impersonated in the request.
@@ -95,6 +102,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param deferrable: If True, run the task in the deferrable mode.
+        Note that it requires calling the operator with `sync=False` parameter.
+    :param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
+        The default is 300 seconds.
     """
 
     template_fields = [
@@ -126,6 +137,8 @@ def __init__(
         experiment: str | experiment_resources.Experiment | None = None,
         gcp_conn_id: str = "google_cloud_default",
         impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poll_interval: int = 5 * 60,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -147,15 +160,12 @@ def __init__(
         self.experiment = experiment
         self.gcp_conn_id = gcp_conn_id
         self.impersonation_chain = impersonation_chain
-        self.hook: PipelineJobHook | None = None
+        self.deferrable = deferrable
+        self.poll_interval = poll_interval
 
     def execute(self, context: Context):
         self.log.info("Running Pipeline job")
-        self.hook = PipelineJobHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-        result = self.hook.run_pipeline_job(
+        pipeline_job_obj: PipelineJob = self.hook.submit_pipeline_job(
             project_id=self.project_id,
             region=self.region,
             display_name=self.display_name,
@@ -173,20 +183,47 @@ def execute(self, context: Context):
             create_request_timeout=self.create_request_timeout,
             experiment=self.experiment,
         )
-
-        pipeline_job = result.to_dict()
-        pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job)
+        if self.deferrable:
+            self.log.info("Pipeline job was created. Job id: %s", pipeline_job_obj.job_id)
+            self.defer(
+                trigger=RunPipelineJobTrigger(
+                    conn_id=self.gcp_conn_id,
+                    project_id=self.project_id,
+                    location=pipeline_job_obj.location,
+                    job_id=pipeline_job_obj.job_id,
+                    poll_interval=self.poll_interval,
+                    impersonation_chain=self.impersonation_chain,
+                ),
+                method_name="execute_complete",
+            )
+        pipeline_job_obj.wait()
+        pipeline_job = pipeline_job_obj.to_dict()
+        pipeline_job_id = pipeline_job_obj.job_id
         self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)
-        self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
         VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
         return pipeline_job
 
+    def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
+        if event["status"] == "error":
+            raise AirflowException(event["message"])
+        job_id = self.hook.extract_pipeline_job_id(event["job"])
+        self.xcom_push(context, key="pipeline_job_id", value=job_id)
+        VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=job_id)
+        return event["job"]
+
     def on_kill(self) -> None:
         """Act as a callback called when the operator is killed; cancel any running job."""
         if self.hook:
             self.hook.cancel_pipeline_job()
 
+    @cached_property
+    def hook(self) -> PipelineJobHook:
+        return PipelineJobHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
 
 class GetPipelineJobOperator(GoogleCloudBaseOperator):
     """
@@ -261,7 +299,7 @@ def execute(self, context: Context):
                 context=context, task_instance=self, pipeline_id=self.pipeline_job_id
             )
             self.log.info("Pipeline job was gotten.")
-            return PipelineJob.to_dict(result)
+            return types.PipelineJob.to_dict(result)
         except NotFound:
             self.log.info("The Pipeline job %s does not exist.", self.pipeline_job_id)
 
@@ -389,7 +427,7 @@ def execute(self, context: Context):
             metadata=self.metadata,
         )
         VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
-        return [PipelineJob.to_dict(result) for result in results]
+        return [types.PipelineJob.to_dict(result) for result in results]
 
 
 class DeletePipelineJobOperator(GoogleCloudBaseOperator):
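To show how the new `deferrable` path of `RunPipelineJobOperator` is meant to be used from a DAG, here is a small sketch (not part of the diff; the DAG id, project, region, and template path are placeholder values):

```python
from __future__ import annotations

from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import RunPipelineJobOperator

with DAG(
    dag_id="example_vertex_ai_run_pipeline_job_deferrable",  # placeholder
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
):
    # With deferrable=True the task submits the pipeline, frees its worker slot,
    # and RunPipelineJobTrigger polls the job state from the triggerer process.
    run_pipeline = RunPipelineJobOperator(
        task_id="run_pipeline_job_deferrable",
        project_id="my-project",  # placeholder
        region="us-central1",  # placeholder
        display_name="example-pipeline",
        template_path="gs://my-bucket/pipeline.yaml",  # placeholder
        deferrable=True,
        poll_interval=60,
    )
```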
diff --git a/airflow/providers/google/cloud/triggers/vertex_ai.py b/airflow/providers/google/cloud/triggers/vertex_ai.py
index b4b121895406f..76e11dee5eb74 100644
--- a/airflow/providers/google/cloud/triggers/vertex_ai.py
+++ b/airflow/providers/google/cloud/triggers/vertex_ai.py
@@ -16,14 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 from typing import Any, AsyncIterator, Sequence
 
-from google.cloud.aiplatform_v1 import HyperparameterTuningJob, JobState
+from google.cloud.aiplatform_v1 import JobState, PipelineState, types
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.vertex_ai.hyperparameter_tuning_job import (
     HyperparameterTuningJobAsyncHook,
 )
+from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobAsyncHook
 from airflow.triggers.base import BaseTrigger, TriggerEvent
 
@@ -89,7 +91,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
             {
                 "status": status,
                 "message": message,
-                "job": HyperparameterTuningJob.to_dict(job),
+                "job": types.HyperparameterTuningJob.to_dict(job),
             }
         )
 
@@ -97,3 +99,87 @@ def _get_async_hook(self) -> HyperparameterTuningJobAsyncHook:
         return HyperparameterTuningJobAsyncHook(
             gcp_conn_id=self.conn_id, impersonation_chain=self.impersonation_chain
         )
+
+
+class RunPipelineJobTrigger(BaseTrigger):
+    """A trigger that makes async calls to Vertex AI to check the state of a running pipeline job."""
+
+    PIPELINE_ERROR_STATES = (
+        PipelineState.PIPELINE_STATE_FAILED,
+        PipelineState.PIPELINE_STATE_CANCELLED,
+    )
+    PIPELINE_COMPLETE_STATES = (
+        PipelineState.PIPELINE_STATE_SUCCEEDED,
+        PipelineState.PIPELINE_STATE_PAUSED,
+        *PIPELINE_ERROR_STATES,
+    )
+
+    def __init__(
+        self,
+        conn_id: str,
+        project_id: str,
+        location: str,
+        job_id: str,
+        poll_interval: int,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.conn_id = conn_id
+        self.project_id = project_id
+        self.location = location
+        self.job_id = job_id
+        self.poll_interval = poll_interval
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            "airflow.providers.google.cloud.triggers.vertex_ai.RunPipelineJobTrigger",
+            {
+                "conn_id": self.conn_id,
+                "project_id": self.project_id,
+                "location": self.location,
+                "job_id": self.job_id,
+                "poll_interval": self.poll_interval,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        hook: PipelineJobAsyncHook = await self._get_async_hook()
+        status = "error"
+        while True:
+            try:
+                pipeline_job_message: types.PipelineJob = await hook.get_pipeline_job(
+                    project_id=self.project_id,
+                    location=self.location,
+                    job_id=self.job_id,
+                )
+                pipeline_job_state: PipelineState = pipeline_job_message.state
+                if pipeline_job_state in self.PIPELINE_COMPLETE_STATES:
+                    status = "success" if pipeline_job_state not in self.PIPELINE_ERROR_STATES else status
+                    yield TriggerEvent(
+                        {
+                            "status": status,
+                            "message": f"Pipeline job '{self.job_id}' has completed with status {pipeline_job_state.name}.",
+                            "job": types.PipelineJob.to_dict(pipeline_job_message),
+                        }
+                    )
+                    return
+                self.log.info("Current pipeline job state: %s.", pipeline_job_state.name)
+                self.log.info("Sleeping for %s seconds...", self.poll_interval)
+                await asyncio.sleep(self.poll_interval)
+            except Exception as exc:
+                yield TriggerEvent(
+                    {
+                        "status": status,
+                        "message": f"Exception occurred when trying to run pipeline job {self.job_id}: {str(exc)}",
+                        "job": None,
+                    }
+                )
+                return
+
+    async def _get_async_hook(self) -> PipelineJobAsyncHook:
+        return PipelineJobAsyncHook(
+            gcp_conn_id=self.conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
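For reference, the event contract that `execute_complete` relies on can be observed by driving the trigger directly; a hedged sketch (illustrative only, placeholders throughout, and it needs a reachable Vertex AI pipeline job to return anything meaningful):

```python
import asyncio

from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger

trigger = RunPipelineJobTrigger(
    conn_id="google_cloud_default",  # placeholder connection id
    project_id="my-project",  # placeholder
    location="us-central1",  # placeholder
    job_id="example-pipeline-job",  # placeholder
    poll_interval=60,
)


async def wait_for_event() -> dict:
    # run() yields a single TriggerEvent once the job reaches a terminal state
    # (or an exception is raised); the payload carries "status", "message" and "job".
    event = await trigger.run().asend(None)
    return event.payload


if __name__ == "__main__":
    payload = asyncio.run(wait_for_event())
    # RunPipelineJobOperator.execute_complete raises on {"status": "error", ...}
    # and otherwise pushes the extracted pipeline job id to XCom.
    print(payload["status"], payload["message"])
```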
diff --git a/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py b/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py
index e784dabb0dfd1..b26c09ef43917 100644
--- a/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py
+++ b/tests/providers/google/cloud/hooks/vertex_ai/test_pipeline_job.py
@@ -19,9 +19,12 @@
 from unittest import mock
 
+import pytest
 from google.api_core.gapic_v1.method import DEFAULT
+from google.cloud.aiplatform_v1 import GetPipelineJobRequest
 
 from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import (
+    PipelineJobAsyncHook,
     PipelineJobHook,
 )
 from tests.providers.google.cloud.utils.base_gcp_mock import (
@@ -30,6 +33,10 @@
 )
 
 TEST_GCP_CONN_ID: str = "test-gcp-conn-id"
+TEST_IMPERSONATION_CHAIN = [
+    "IMPERSONATE",
+    "THIS",
+]
 TEST_REGION: str = "test-region"
 TEST_PROJECT_ID: str = "test-project-id"
 TEST_PIPELINE_JOB: dict = {}
@@ -39,6 +46,14 @@
 PIPELINE_JOB_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.{}"
 
 
+@pytest.fixture
+def test_async_hook():
+    return PipelineJobAsyncHook(
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+    )
+
+
 class TestPipelineJobWithDefaultProjectIdHook:
     def setup_method(self):
         with mock.patch(
@@ -217,3 +232,31 @@ def test_list_pipeline_jobs(self, mock_client) -> None:
             timeout=None,
         )
         mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)
+
+
+class TestPipelineJobAsyncHook:
+    @pytest.mark.asyncio
+    @mock.patch(PIPELINE_JOB_STRING.format("PipelineJobAsyncHook.get_pipeline_service_client"))
+    async def test_get_pipeline_job(self, mock_get_pipeline_service_client, test_async_hook):
+        await test_async_hook.get_pipeline_job(
+            project_id=TEST_PROJECT_ID, location=TEST_REGION, job_id=TEST_PIPELINE_JOB_ID
+        )
+        mock_get_pipeline_service_client.assert_called_once_with(region=TEST_REGION)
+        mock_get_pipeline_service_client.return_value.get_pipeline_job.assert_called_once_with(
+            request=GetPipelineJobRequest(
+                name=f"projects/{TEST_PROJECT_ID}/locations/{TEST_REGION}/pipelineJobs/{TEST_PIPELINE_JOB_ID}",
+            ),
+            retry=DEFAULT,
+            timeout=DEFAULT,
+            metadata=(),
+        )
+
+    def test_get_pipeline_job_request(self, test_async_hook):
+        expected_request = GetPipelineJobRequest(
+            name=f"projects/{TEST_PROJECT_ID}/locations/{TEST_REGION}/pipelineJobs/{TEST_PIPELINE_JOB_ID}",
+        )
+        actual_request = test_async_hook.get_pipeline_job_request(
+            project=TEST_PROJECT_ID,
+            location=TEST_REGION,
+            job=TEST_PIPELINE_JOB_ID,
+        )
+        assert actual_request == expected_request
diff --git a/tests/providers/google/cloud/operators/test_vertex_ai.py b/tests/providers/google/cloud/operators/test_vertex_ai.py
index 686b7ff7c2320..88b44dbd73a56 100644
--- a/tests/providers/google/cloud/operators/test_vertex_ai.py
+++ b/tests/providers/google/cloud/operators/test_vertex_ai.py
@@ -23,7 +23,7 @@
 from google.api_core.gapic_v1.method import DEFAULT
 from google.api_core.retry import Retry
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
     CreateAutoMLForecastingTrainingJobOperator,
     CreateAutoMLImageTrainingJobOperator,
@@ -84,6 +84,7 @@
     ListPipelineJobOperator,
     RunPipelineJobOperator,
 )
+from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger
 
 VERTEX_AI_PATH = "airflow.providers.google.cloud.operators.vertex_ai.{}"
 TIMEOUT = 120
@@ -1842,9 +1843,9 @@ def test_execute(self, mock_hook, to_dict_mock):
 
 
 class TestVertexAIRunPipelineJobOperator:
-    @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJob.to_dict"))
     @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
-    def test_execute(self, mock_hook, to_dict_mock):
+    @mock.patch("google.cloud.aiplatform_v1.types.PipelineJob.to_dict")
+    def test_execute(self, to_dict_mock, mock_hook):
         op = RunPipelineJobOperator(
             task_id=TASK_ID,
             gcp_conn_id=GCP_CONN_ID,
@@ -1868,7 +1869,7 @@ def test_execute(self, mock_hook, to_dict_mock):
         )
         op.execute(context={"ti": mock.MagicMock()})
         mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
-        mock_hook.return_value.run_pipeline_job.assert_called_once_with(
+        mock_hook.return_value.submit_pipeline_job.assert_called_once_with(
             project_id=GCP_PROJECT,
             region=GCP_LOCATION,
             display_name=DISPLAY_NAME,
@@ -1887,9 +1888,70 @@ def test_execute(self, mock_hook, to_dict_mock):
             experiment=None,
         )
 
+    @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
+    def test_execute_enters_deferred_state(self, mock_hook):
+        task = RunPipelineJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            display_name=DISPLAY_NAME,
+            template_path=TEST_TEMPLATE_PATH,
+            job_id=TEST_PIPELINE_JOB_ID,
+            deferrable=True,
+        )
+        mock_hook.return_value.exists.return_value = False
+        with pytest.raises(TaskDeferred) as exc:
+            task.execute(context={"ti": mock.MagicMock()})
+        assert isinstance(exc.value.trigger, RunPipelineJobTrigger), "Trigger is not a RunPipelineJobTrigger"
+
+    @mock.patch(
+        "airflow.providers.google.cloud.operators.vertex_ai.pipeline_job.RunPipelineJobOperator.xcom_push"
+    )
+    @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
+    def test_execute_complete_success(self, mock_hook, mock_xcom_push):
+        task = RunPipelineJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            display_name=DISPLAY_NAME,
+            template_path=TEST_TEMPLATE_PATH,
+            job_id=TEST_PIPELINE_JOB_ID,
+            deferrable=True,
+        )
+        expected_pipeline_job = expected_result = {
+            "name": f"projects/{GCP_PROJECT}/locations/{GCP_LOCATION}/pipelineJobs/{TEST_PIPELINE_JOB_ID}",
+        }
+        mock_hook.return_value.exists.return_value = False
+        mock_xcom_push.return_value = None
+        actual_result = task.execute_complete(
+            context=None, event={"status": "success", "message": "", "job": expected_pipeline_job}
+        )
+        assert actual_result == expected_result
+
+    def test_execute_complete_error_status_raises_exception(self):
+        task = RunPipelineJobOperator(
+            task_id=TASK_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            display_name=DISPLAY_NAME,
+            template_path=TEST_TEMPLATE_PATH,
+            job_id=TEST_PIPELINE_JOB_ID,
+            deferrable=True,
+        )
+        with pytest.raises(AirflowException):
+            task.execute_complete(
+                context=None, event={"status": "error", "message": "test message", "job": None}
+            )
+
 
 class TestVertexAIGetPipelineJobOperator:
-    @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJob.to_dict"))
+    @mock.patch("google.cloud.aiplatform_v1.types.PipelineJob.to_dict")
     @mock.patch(VERTEX_AI_PATH.format("pipeline_job.PipelineJobHook"))
     def test_execute(self, mock_hook, to_dict_mock):
         op = GetPipelineJobOperator(
diff --git a/tests/providers/google/cloud/triggers/test_vertex_ai.py b/tests/providers/google/cloud/triggers/test_vertex_ai.py
index 5d68e171cf45a..24911e9796dd5 100644
--- a/tests/providers/google/cloud/triggers/test_vertex_ai.py
+++ b/tests/providers/google/cloud/triggers/test_vertex_ai.py
@@ -16,14 +16,17 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 from unittest import mock
-from unittest.mock import patch
 
 import pytest
-from google.cloud.aiplatform_v1 import JobState
+from google.cloud.aiplatform_v1 import JobState, PipelineState, types
 
 from airflow.exceptions import AirflowException
-from airflow.providers.google.cloud.triggers.vertex_ai import CreateHyperparameterTuningJobTrigger
+from airflow.providers.google.cloud.triggers.vertex_ai import (
+    CreateHyperparameterTuningJobTrigger,
+    RunPipelineJobTrigger,
+)
 from airflow.triggers.base import TriggerEvent
 
 TEST_CONN_ID = "test_connection"
@@ -50,6 +53,18 @@ def create_hyperparameter_tuning_job_trigger():
     )
 
 
+@pytest.fixture
+def run_pipeline_job_trigger():
+    return RunPipelineJobTrigger(
+        conn_id=TEST_CONN_ID,
+        project_id=TEST_PROJECT_ID,
+        location=TEST_LOCATION,
+        job_id=TEST_HPT_JOB_ID,
+        poll_interval=TEST_POLL_INTERVAL,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+    )
+
+
 class TestCreateHyperparameterTuningJobTrigger:
     def test_serialize(self, create_hyperparameter_tuning_job_trigger):
         classpath, kwargs = create_hyperparameter_tuning_job_trigger.serialize()
@@ -66,7 +81,7 @@ def test_serialize(self, create_hyperparameter_tuning_job_trigger):
             impersonation_chain=TEST_IMPERSONATION_CHAIN,
         )
 
-    @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
     def test_get_async_hook(self, mock_async_hook, create_hyperparameter_tuning_job_trigger):
         hook_expected = mock_async_hook.return_value
 
@@ -87,8 +102,8 @@ def test_get_async_hook(self, mock_async_hook, create_hyperparameter_tuning_job_
             (JobState.JOB_STATE_SUCCEEDED, "success"),
         ],
     )
-    @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
-    @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJob"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
+    @mock.patch("google.cloud.aiplatform_v1.types.HyperparameterTuningJob")
     async def test_run(
         self, mock_hpt_job, mock_async_hook, state, status, create_hyperparameter_tuning_job_trigger
     ):
@@ -122,7 +137,7 @@ async def test_run(
         )
 
     @pytest.mark.asyncio
-    @patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
+    @mock.patch(VERTEX_AI_TRIGGER_PATH.format("HyperparameterTuningJobAsyncHook"))
     async def test_run_exception(self, mock_async_hook, create_hyperparameter_tuning_job_trigger):
         mock_async_hook.return_value.wait_hyperparameter_tuning_job.side_effect = AirflowException(
             "test error"
         )
@@ -137,3 +152,127 @@ async def test_run_exception(self, mock_async_hook, create_hyperparameter_tuning
                 "message": "test error",
             }
         )
+
+
+class TestRunPipelineJobTrigger:
+    def test_serialize(self, run_pipeline_job_trigger):
+        actual_data = run_pipeline_job_trigger.serialize()
+        expected_data = (
+            "airflow.providers.google.cloud.triggers.vertex_ai.RunPipelineJobTrigger",
+            {
+                "conn_id": TEST_CONN_ID,
+                "project_id": TEST_PROJECT_ID,
+                "location": TEST_LOCATION,
+                "job_id": TEST_HPT_JOB_ID,
+                "poll_interval": TEST_POLL_INTERVAL,
+                "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+            },
+        )
+        assert actual_data == expected_data
+
+    @pytest.mark.asyncio
+    async def test_get_async_hook(self, run_pipeline_job_trigger):
+        hook = await run_pipeline_job_trigger._get_async_hook()
+        actual_conn_id = hook._hook_kwargs.get("gcp_conn_id")
+        actual_imp_chain = hook._hook_kwargs.get("impersonation_chain")
+        assert (actual_conn_id, actual_imp_chain) == (TEST_CONN_ID, TEST_IMPERSONATION_CHAIN)
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_SUCCEEDED,
+            PipelineState.PIPELINE_STATE_PAUSED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.PipelineJob.to_dict")
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.PipelineJobAsyncHook.get_pipeline_job"
+    )
+    async def test_run_yields_success_event_on_successful_pipeline_state(
+        self,
+        mock_get_pipeline_job,
+        mock_pipeline_job_dict,
+        run_pipeline_job_trigger,
+        pipeline_state_value,
+    ):
+        mock_get_pipeline_job.return_value = types.PipelineJob(state=pipeline_state_value)
+        mock_pipeline_job_dict.return_value = {}
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": f"Pipeline job '{TEST_HPT_JOB_ID}' has completed with status {pipeline_state_value.name}.",
+                "job": {},
+            }
+        )
+        actual_event = await run_pipeline_job_trigger.run().asend(None)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_FAILED,
+            PipelineState.PIPELINE_STATE_CANCELLED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch("google.cloud.aiplatform_v1.types.PipelineJob.to_dict")
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.PipelineJobAsyncHook.get_pipeline_job"
+    )
+    async def test_run_yields_error_event_on_failed_pipeline_state(
+        self, mock_get_pipeline_job, mock_pipeline_job_dict, pipeline_state_value, run_pipeline_job_trigger
+    ):
+        mock_get_pipeline_job.return_value = types.PipelineJob(state=pipeline_state_value)
+        mock_pipeline_job_dict.return_value = {}
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Pipeline job '{TEST_HPT_JOB_ID}' has completed with status {pipeline_state_value.name}.",
+                "job": {},
+            }
+        )
+        actual_event = await run_pipeline_job_trigger.run().asend(None)
+        assert actual_event == expected_event
+
+    @pytest.mark.parametrize(
+        "pipeline_state_value",
+        [
+            PipelineState.PIPELINE_STATE_CANCELLING,
+            PipelineState.PIPELINE_STATE_PENDING,
+            PipelineState.PIPELINE_STATE_QUEUED,
+            PipelineState.PIPELINE_STATE_RUNNING,
+            PipelineState.PIPELINE_STATE_UNSPECIFIED,
+        ],
+    )
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.PipelineJobAsyncHook.get_pipeline_job"
+    )
+    async def test_run_loop_is_still_running_if_pipeline_is_running(
+        self, mock_get_pipeline_job, pipeline_state_value, run_pipeline_job_trigger
+    ):
+        mock_get_pipeline_job.return_value = types.PipelineJob(state=pipeline_state_value)
+        task = asyncio.create_task(run_pipeline_job_trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is False
+        task.cancel()
+
+    @pytest.mark.asyncio
+    @mock.patch(
+        "airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job.PipelineJobAsyncHook.get_pipeline_job"
+    )
+    async def test_run_raises_exception(self, mock_get_pipeline_job, run_pipeline_job_trigger):
+        """
+        Test that RunPipelineJobTrigger fires an error event if an exception is raised.
+        """
+        mock_get_pipeline_job.side_effect = mock.AsyncMock(side_effect=Exception("Test exception"))
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": f"Exception occurred when trying to run pipeline job {TEST_HPT_JOB_ID}: Test exception",
+                "job": None,
+            }
+        )
+        actual_event = await run_pipeline_job_trigger.run().asend(None)
+        assert expected_event == actual_event