Add the deferrable mode to RunPipelineJobOperator #17

Closed

Changes from 3 commits
@@ -39,8 +39,7 @@

if TYPE_CHECKING:
from google.api_core.operation import Operation
- from google.api_core.retry import Retry
- from google.api_core.retry_async import AsyncRetry
+ from google.api_core.retry import AsyncRetry, Retry
from google.cloud.aiplatform_v1.services.job_service.pagers import ListHyperparameterTuningJobsPager


128 changes: 107 additions & 21 deletions airflow/providers/google/cloud/hooks/vertex_ai/pipeline_job.py
@@ -28,16 +28,21 @@
from google.api_core.client_options import ClientOptions
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
from google.cloud.aiplatform import PipelineJob
- from google.cloud.aiplatform_v1 import PipelineServiceClient
+ from google.cloud.aiplatform_v1 import (
+ GetPipelineJobRequest,
+ PipelineServiceAsyncClient,
+ PipelineServiceClient,
+ types,
+ )

from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
- from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+ from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook

if TYPE_CHECKING:
from google.api_core.operation import Operation
- from google.api_core.retry import Retry
- from google.cloud.aiplatform.metadata import experiment_resources
+ from google.api_core.retry import AsyncRetry, Retry
+ from google.auth.credentials import Credentials
from google.cloud.aiplatform_v1.services.pipeline_service.pagers import ListPipelineJobsPager


@@ -101,11 +106,6 @@ def get_pipeline_job_object(
failure_policy=failure_policy,
)

- @staticmethod
- def extract_pipeline_job_id(obj: dict) -> str:
-     """Return unique id of the pipeline_job."""
-     return obj["name"].rpartition("/")[-1]

def wait_for_operation(self, operation: Operation, timeout: float | None = None):
"""Wait for long-lasting operation to complete."""
try:
@@ -129,7 +129,7 @@ def create_pipeline_job(
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
- ) -> PipelineJob:
+ ) -> types.PipelineJob:
"""
Create a PipelineJob. A PipelineJob will run immediately when created.

@@ -178,11 +178,17 @@ def run_pipeline_job(
service_account: str | None = None,
network: str | None = None,
create_request_timeout: float | None = None,
experiment: str | experiment_resources.Experiment | None = None,
# END: run param
sync=True,
) -> PipelineJob:
"""
- Run PipelineJob and monitor the job until completion.
+ Create and run a PipelineJob.

If sync is True, the method keeps running until the job completes.
If sync is False, the method returns after the job's resources have been created.

For more info please see:
https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.PipelineJob#google_cloud_aiplatform_PipelineJob_run

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
@@ -224,10 +230,9 @@ def run_pipeline_job(
Private services access must already be configured for the network. If left unspecified, the
network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
:param create_request_timeout: Optional. The timeout for the create request in seconds.
- :param experiment: Optional. The Vertex AI experiment name or instance to associate to this
- PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as
- metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to
- the current Experiment Run.
+ :param sync: Whether to execute this method synchronously.
+ If False, the method returns immediately and the job runs in a concurrent Future.
+ The default is True.
"""
self._pipeline_job = self.get_pipeline_job_object(
display_name=display_name,
@@ -243,15 +248,18 @@ def run_pipeline_job(
location=region,
failure_policy=failure_policy,
)

- self._pipeline_job.submit(
+ self._pipeline_job.run(
service_account=service_account,
network=network,
create_request_timeout=create_request_timeout,
experiment=experiment,
sync=sync,
)

- self._pipeline_job.wait()
+ if sync:
+ self._pipeline_job.wait()
+ else:
+ self._pipeline_job._wait_for_resource_creation()

return self._pipeline_job

@GoogleBaseHook.fallback_to_default_project_id
@@ -263,7 +271,7 @@ def get_pipeline_job(
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
- ) -> PipelineJob:
+ ) -> types.PipelineJob:
"""
Get a PipelineJob.

@@ -407,3 +415,81 @@ def delete_pipeline_job(
metadata=metadata,
)
return result


class PipelineJobAsyncHook(GoogleBaseAsyncHook):
"""Asynchronous hook for Google Cloud Vertex AI Pipeline Job APIs."""

sync_hook_class = PipelineJobHook

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)

async def get_credentials(self) -> Credentials:
credentials = (await self.get_sync_hook()).get_credentials()
return credentials

async def get_project_id(self) -> str:
sync_hook = await self.get_sync_hook()
return sync_hook.project_id

async def get_project_location(self) -> str:
Collaborator review comment, suggested change:
- async def get_project_location(self) -> str:
+ async def get_location(self) -> str:
sync_hook = await self.get_sync_hook()
return sync_hook.location

async def get_pipeline_service_client(
self,
region: str | None = None,
) -> PipelineServiceAsyncClient:
"""Return PipelineServiceAsyncClient object."""
if region and region != "global":
client_options = ClientOptions(api_endpoint=f"{region}-aiplatform.googleapis.com:443")
else:
client_options = ClientOptions()
return PipelineServiceAsyncClient(
credentials=await self.get_credentials(),
client_info=CLIENT_INFO,
client_options=client_options,
)

async def get_pipeline_job(
self,
project_id: str,
location: str,
job_id: str,
retry: AsyncRetry | _MethodDefault = DEFAULT,
timeout: float | _MethodDefault | None = DEFAULT,
metadata: Sequence[tuple[str, str]] = (),
) -> types.PipelineJob:
"""
Get a PipelineJob message from PipelineServiceAsyncClient.

:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param location: Required. The ID of the Google Cloud region that the service belongs to.
:param job_id: Required. The ID of the PipelineJob resource.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = await self.get_pipeline_service_client(region=location)
request = self.get_pipeline_job_request(project=project_id, location=location, job=job_id)
response: types.PipelineJob = await client.get_pipeline_job(
request=request,
retry=retry,
timeout=timeout,
metadata=metadata,
)
return response

@staticmethod
def get_pipeline_job_request(project: str, location: str, job: str) -> GetPipelineJobRequest:
name = f"projects/{project}/locations/{location}/pipelineJobs/{job}"
return GetPipelineJobRequest(name=name)
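For context: the deferrable path relies on a trigger that polls this hook's get_pipeline_job until the pipeline reaches a terminal state (the RunPipelineJobTrigger itself is not part of this diff). A minimal sketch of that polling loop; the terminal-state set, connection id, and interval below are assumptions, not the trigger's actual implementation:

```python
import asyncio

from google.cloud.aiplatform_v1 import PipelineState

from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobAsyncHook

# Assumed terminal states; the real trigger may check a different set.
TERMINAL_STATES = {
    PipelineState.PIPELINE_STATE_SUCCEEDED,
    PipelineState.PIPELINE_STATE_FAILED,
    PipelineState.PIPELINE_STATE_CANCELLED,
}


async def poll_pipeline_job(project_id: str, location: str, job_id: str, poll_interval: int = 300):
    """Poll the async hook until the pipeline job reaches a terminal state."""
    hook = PipelineJobAsyncHook(gcp_conn_id="google_cloud_default")
    while True:
        job = await hook.get_pipeline_job(project_id=project_id, location=location, job_id=job_id)
        if job.state in TERMINAL_STATES:
            return job
        await asyncio.sleep(poll_interval)
```

Running this wait loop in the triggerer, instead of calling PipelineJob.wait() on a worker, is what frees the worker slot while the job runs.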
76 changes: 60 additions & 16 deletions airflow/providers/google/cloud/operators/vertex_ai/pipeline_job.py
@@ -22,18 +22,21 @@

from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
- from google.cloud.aiplatform_v1.types import PipelineJob
+ from google.cloud.aiplatform_v1 import types

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.vertex_ai.pipeline_job import PipelineJobHook
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIPipelineJobLink,
VertexAIPipelineJobListLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.vertex_ai import RunPipelineJobTrigger

if TYPE_CHECKING:
from google.api_core.retry import Retry
- from google.cloud.aiplatform.metadata import experiment_resources
+ from google.cloud.aiplatform import PipelineJob

from airflow.utils.context import Context

@@ -82,11 +85,9 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
Private services access must already be configured for the network. If left unspecified, the
network set in aiplatform.init will be used. Otherwise, the job is not peered with any network.
:param create_request_timeout: Optional. The timeout for the create request in seconds.
- :param experiment: Optional. The Vertex AI experiment name or instance to associate to this
- PipelineJob. Metrics produced by the PipelineJob as system.Metric Artifacts will be associated as
- metrics to the current Experiment Run. Pipeline parameters will be associated as parameters to
- the current Experiment Run.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param sync: Whether to execute this method synchronously. If False, the method returns
+ immediately and the job runs in a concurrent Future. The default is True.
Collaborator review comment on lines +95 to +96: Haven't we agreed to remove this parameter?

:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
@@ -95,6 +96,10 @@ class RunPipelineJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: If True, run the task in deferrable mode.
Note that deferrable mode requires the operator to be called with `sync=False`.
:param poll_interval: Time (seconds) to wait between two consecutive calls to check the job.
The default is 300 seconds.
"""

template_fields = [
@@ -123,9 +128,11 @@ def __init__(
service_account: str | None = None,
network: str | None = None,
create_request_timeout: float | None = None,
- experiment: str | experiment_resources.Experiment | None = None,
gcp_conn_id: str = "google_cloud_default",
sync: bool = True,
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: int = 5 * 60,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -144,18 +151,21 @@ def __init__(
self.service_account = service_account
self.network = network
self.create_request_timeout = create_request_timeout
- self.experiment = experiment
self.gcp_conn_id = gcp_conn_id
self.sync = sync
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.poll_interval = poll_interval
self.hook: PipelineJobHook | None = None

def execute(self, context: Context):
self.validate_sync_parameter()
self.log.info("Running Pipeline job")
self.hook = PipelineJobHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
- result = self.hook.run_pipeline_job(
+ pipeline_job_obj: PipelineJob = self.hook.run_pipeline_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
@@ -171,22 +181,56 @@ def execute(self, context: Context):
service_account=self.service_account,
network=self.network,
create_request_timeout=self.create_request_timeout,
- experiment=self.experiment,
+ sync=self.sync,
)

- pipeline_job = result.to_dict()
- pipeline_job_id = self.hook.extract_pipeline_job_id(pipeline_job)
if self.deferrable:
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_obj.job_id)
self.defer(
trigger=RunPipelineJobTrigger(
conn_id=self.gcp_conn_id,
project_id=self.project_id,
location=pipeline_job_obj.location,
job_id=pipeline_job_obj.job_id,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
pipeline_job = pipeline_job_obj.to_dict()
pipeline_job_id = self.extract_pipeline_job_id(pipeline_job)
self.log.info("Pipeline job was created. Job id: %s", pipeline_job_id)

self.xcom_push(context, key="pipeline_job_id", value=pipeline_job_id)
VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=pipeline_job_id)
return pipeline_job

def execute_complete(self, context: Context, event: dict[str, Any]) -> dict[str, Any]:
if event["status"] == "error":
raise AirflowException(event["message"])
job_id = self.extract_pipeline_job_id(event["job"])
self.xcom_push(context, key="pipeline_job_id", value=job_id)
VertexAIPipelineJobLink.persist(context=context, task_instance=self, pipeline_id=job_id)
return event["job"]

def on_kill(self) -> None:
"""Act as a callback called when the operator is killed; cancel any running job."""
self.log.warning("This is hook on kill: %s", self.hook)
if self.hook:
self.hook.cancel_pipeline_job()

def validate_sync_parameter(self) -> None:
if self.deferrable and self.sync:
raise AirflowException(
"Deferrable mode can only be used with sync=False. "
"To run the operator in deferrable mode, set sync=False; "
"otherwise, set deferrable=False."
)

@staticmethod
def extract_pipeline_job_id(obj: dict) -> str:
"""Return unique id of a pipeline job from its name."""
return obj["name"].rpartition("/")[-1]


class GetPipelineJobOperator(GoogleCloudBaseOperator):
"""
@@ -261,7 +305,7 @@ def execute(self, context: Context):
context=context, task_instance=self, pipeline_id=self.pipeline_job_id
)
self.log.info("Pipeline job was gotten.")
- return PipelineJob.to_dict(result)
+ return types.PipelineJob.to_dict(result)
except NotFound:
self.log.info("The Pipeline job %s does not exist.", self.pipeline_job_id)

@@ -389,7 +433,7 @@ def execute(self, context: Context):
metadata=self.metadata,
)
VertexAIPipelineJobListLink.persist(context=context, task_instance=self)
- return [PipelineJob.to_dict(result) for result in results]
+ return [types.PipelineJob.to_dict(result) for result in results]


class DeletePipelineJobOperator(GoogleCloudBaseOperator):
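A usage sketch of the new deferrable mode (the DAG id, project, region, and template path below are placeholders; deferrable=True must be combined with sync=False, which validate_sync_parameter enforces):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.vertex_ai.pipeline_job import RunPipelineJobOperator

with DAG("example_vertex_ai_pipeline_job", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    run_pipeline_job = RunPipelineJobOperator(
        task_id="run_pipeline_job",
        project_id="my-project",                       # placeholder
        region="us-central1",                          # placeholder
        display_name="example-pipeline-job",
        template_path="gs://my-bucket/pipeline.json",  # placeholder
        sync=False,        # required when deferrable=True
        deferrable=True,   # hand the wait over to RunPipelineJobTrigger
        poll_interval=60,  # seconds between status checks (default 300)
    )
```

With deferrable=True the task submits the job, defers, and the triggerer polls the job status, so no worker slot is held while the pipeline runs.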