Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deferrable mode for Dataflow sensors #14

63 changes: 62 additions & 1 deletion airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@
from typing import TYPE_CHECKING, Any, Callable, Generator, Sequence, TypeVar, cast

from deprecated import deprecated
from google.cloud.dataflow_v1beta3 import GetJobRequest, Job, JobState, JobsV1Beta3AsyncClient, JobView
from google.cloud.dataflow_v1beta3 import (
GetJobRequest,
Job,
JobState,
JobsV1Beta3AsyncClient,
JobView,
ListJobMessagesRequest,
MessagesV1Beta3AsyncClient,
)
from google.cloud.dataflow_v1beta3.types import JobMessageImportance
from google.cloud.dataflow_v1beta3.types.jobs import ListJobsRequest
from googleapiclient.discovery import build

Expand All @@ -46,6 +55,8 @@

if TYPE_CHECKING:
from google.cloud.dataflow_v1beta3.services.jobs_v1_beta3.pagers import ListJobsAsyncPager
from google.cloud.dataflow_v1beta3.services.messages_v1_beta3.pagers import ListJobMessagesAsyncPager
from google.protobuf.timestamp_pb2 import Timestamp


# This is the default location
Expand Down Expand Up @@ -1352,3 +1363,53 @@ async def list_jobs(
)
page_result: ListJobsAsyncPager = await client.list_jobs(request=request)
return page_result

async def list_job_messages(
    self,
    job_id: str,
    project_id: str | None = PROVIDE_PROJECT_ID,
    minimum_importance: int = JobMessageImportance.JOB_MESSAGE_BASIC,
    page_size: int | None = None,
    page_token: str | None = None,
    start_time: Timestamp | None = None,
    end_time: Timestamp | None = None,
    location: str | None = DEFAULT_DATAFLOW_LOCATION,
) -> ListJobMessagesAsyncPager:
    """
    Fetch the messages of a Dataflow job as a ListJobMessagesAsyncPager.

    This is a thin asynchronous wrapper around ``MessagesV1Beta3AsyncClient.list_job_messages``.
    The returned pager can be iterated over to extract the messages associated with the given job ID.

    For more details see the MessagesV1Beta3AsyncClient method description at:
    https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.services.messages_v1_beta3.MessagesV1Beta3AsyncClient

    :param job_id: ID of the Dataflow job to get messages about.
    :param project_id: Optional. The Google Cloud project ID that the job belongs to.
        If set to None or missing, the default project_id from the Google Cloud connection is used.
    :param minimum_importance: Optional. Value sent as ``minimum_importance`` in the request.
        The default is ``JobMessageImportance.JOB_MESSAGE_BASIC``.
    :param page_size: Optional. If specified, determines the maximum number of messages to return.
        If unspecified, the service may choose an appropriate default, or may return an arbitrarily
        large number of results.
    :param page_token: Optional. If supplied, this should be the value of next_page_token returned
        by an earlier call. This will cause the next page of results to be returned.
    :param start_time: Optional. If specified, return only messages with timestamps >= start_time.
        The default is the job creation time (i.e. beginning of messages).
    :param end_time: Optional. If specified, return only messages with timestamps < end_time.
        The default is the current time.
    :param location: Optional. The regional endpoint
        (https://cloud.google.com/dataflow/docs/concepts/regional-endpoints) that contains
        the job specified by job_id.
    :return: The pager returned by the client for the built request.
    """
    # Fall back to the project configured on the connection when none is given.
    project_id = project_id or (await self.get_project_id())
    messages_client = await self.initialize_client(MessagesV1Beta3AsyncClient)

    request_fields = {
        "project_id": project_id,
        "job_id": job_id,
        "minimum_importance": minimum_importance,
        "page_size": page_size,
        "page_token": page_token,
        "start_time": start_time,
        "end_time": end_time,
        "location": location,
    }
    messages_pager: ListJobMessagesAsyncPager = await messages_client.list_job_messages(
        request=ListJobMessagesRequest(request_fields)
    )
    return messages_pager
124 changes: 104 additions & 20 deletions airflow/providers/google/cloud/sensors/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@
"""This module contains a Google Cloud Dataflow sensor."""
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Sequence
from typing import TYPE_CHECKING, Any, Callable, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
DataflowJobStatus,
)
from airflow.providers.google.cloud.triggers.dataflow import (
DataflowJobAutoScalingEventTrigger,
DataflowJobMessagesTrigger,
)
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,21 +204,25 @@ def poke(self, context: Context) -> bool:

class DataflowJobMessagesSensor(BaseSensorOperator):
"""
Checks for the job message in Google Cloud Dataflow.
Checks for job messages associated with a single job in Google Cloud Dataflow.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DataflowJobMessagesSensor`

:param job_id: ID of the job to be checked.
:param callback: callback which is called with list of read job metrics
See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate
:param fail_on_terminal_state: If set to true sensor will raise Exception when
job is in terminal state
:param job_id: ID of the Dataflow job to be checked.
:param callback: a function that can accept a list of serialized job messages.
It can do whatever you want it to do. If the callback function is not provided,
then on successful completion the task will exit with True value.
For more info about the job message content see:
https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.JobMessage
:param fail_on_terminal_state: If set to True the sensor will raise an exception when the job is in a terminal state.
No job messages will be returned. The default value is True.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param location: Job location.
:param location: The location of the Dataflow job (for example europe-west1).
If set to None then the value of DEFAULT_DATAFLOW_LOCATION will be used.
See: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
Expand All @@ -223,6 +232,7 @@ class DataflowJobMessagesSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: If True, run the sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_id",)
Expand All @@ -237,6 +247,7 @@ def __init__(
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -247,14 +258,14 @@ def __init__(
self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.hook: DataflowHook | None = None

def poke(self, context: Context) -> bool:
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

if self.fail_on_terminal_state:
job = self.hook.get_job(
job_id=self.job_id,
Expand All @@ -277,24 +288,62 @@ def poke(self, context: Context) -> bool:

return self.callback(result)

def execute(self, context: Context) -> Any:
    """
    Run the sensor on the worker, either synchronously or in the deferrable mode.

    When ``deferrable`` is False, the parent class ``execute`` is run and its result is returned.
    When ``deferrable`` is True, the task is deferred to ``DataflowJobMessagesTrigger`` and
    resumes in ``execute_complete`` once the trigger fires.

    :param context: The Airflow task context.
    :return: The value returned by the parent ``execute`` in the non-deferrable mode.
    """
    if not self.deferrable:
        # Return the parent's result instead of silently dropping it.
        return super().execute(context)

    self.defer(
        timeout=self.execution_timeout,
        trigger=DataflowJobMessagesTrigger(
            job_id=self.job_id,
            project_id=self.project_id,
            location=self.location,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
            fail_on_terminal_state=self.fail_on_terminal_state,
        ),
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, str | list]) -> Any:
    """
    Resume the task on the worker after the trigger has fired.

    On an event with the ``success`` status the event message is logged and the event result
    is handed to the callback; True is returned when no callback was provided.

    On any other status an exception carrying the event message is raised:
    ``AirflowSkipException`` when ``soft_fail`` is set, ``AirflowException`` otherwise.

    :param context: The Airflow task context.
    :param event: The trigger event; ``status`` and ``message`` are read, plus ``result`` on success.
    """
    if event["status"] != "success":
        if self.soft_fail:
            raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.")
        raise AirflowException(f"Sensor failed with the following message: {event['message']}")

    self.log.info(event["message"])
    if self.callback is None:
        return True
    return self.callback(event["result"])


class DataflowJobAutoScalingEventsSensor(BaseSensorOperator):
"""
Checks for the job autoscaling event in Google Cloud Dataflow.
Checks for autoscaling events associated with a single job in Google Cloud Dataflow.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DataflowJobAutoScalingEventsSensor`

:param job_id: ID of the job to be checked.
:param callback: callback which is called with list of read job metrics
See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/MetricUpdate
:param fail_on_terminal_state: If set to true sensor will raise Exception when
job is in terminal state
:param job_id: ID of the Dataflow job to be checked.
:param callback: a function that can accept a list of serialized autoscaling events.
It can do whatever you want it to do. If the callback function is not provided,
then on successful completion the task will exit with True value.
For more info about the autoscaling event content see:
https://cloud.google.com/python/docs/reference/dataflow/latest/google.cloud.dataflow_v1beta3.types.AutoscalingEvent
:param fail_on_terminal_state: If set to True the sensor will raise an exception when the job is in a terminal state.
No autoscaling events will be returned. The default value is True.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
If set to None or missing, the default project_id from the Google Cloud connection is used.
:param location: Job location.
:param location: The location of the Dataflow job (for example europe-west1).
If set to None then the value of DEFAULT_DATAFLOW_LOCATION will be used.
See: https://cloud.google.com/dataflow/docs/concepts/regional-endpoints
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
Expand All @@ -304,6 +353,7 @@ class DataflowJobAutoScalingEventsSensor(BaseSensorOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param deferrable: If True, run the sensor in the deferrable mode.
"""

template_fields: Sequence[str] = ("job_id",)
Expand All @@ -312,12 +362,13 @@ def __init__(
self,
*,
job_id: str,
callback: Callable,
callback: Callable | None,
fail_on_terminal_state: bool = True,
project_id: str | None = None,
location: str = DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -328,14 +379,14 @@ def __init__(
self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.deferrable = deferrable
self.hook: DataflowHook | None = None

def poke(self, context: Context) -> bool:
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

if self.fail_on_terminal_state:
job = self.hook.get_job(
job_id=self.job_id,
Expand All @@ -357,3 +408,36 @@ def poke(self, context: Context) -> bool:
)

return self.callback(result)

def execute(self, context: Context) -> Any:
    """
    Run the sensor on the worker, either synchronously or in the deferrable mode.

    When ``deferrable`` is False, the parent class ``execute`` is run and its result is returned.
    When ``deferrable`` is True, the task is deferred to ``DataflowJobAutoScalingEventTrigger``
    and resumes in ``execute_complete`` once the trigger fires.

    :param context: The Airflow task context.
    :return: The value returned by the parent ``execute`` in the non-deferrable mode.
    """
    if not self.deferrable:
        # Return the parent's result instead of silently dropping it.
        return super().execute(context)

    self.defer(
        # Apply the task's execution timeout while deferred, as DataflowJobMessagesSensor does.
        timeout=self.execution_timeout,
        trigger=DataflowJobAutoScalingEventTrigger(
            job_id=self.job_id,
            project_id=self.project_id,
            location=self.location,
            gcp_conn_id=self.gcp_conn_id,
            impersonation_chain=self.impersonation_chain,
            fail_on_terminal_state=self.fail_on_terminal_state,
        ),
        method_name="execute_complete",
    )

def execute_complete(self, context: Context, event: dict[str, str | list]) -> Any:
    """
    Resume the task on the worker after the trigger has fired.

    On an event with the ``success`` status the event message is logged and the event result
    is handed to the callback; True is returned when no callback was provided.

    On any other status an exception carrying the event message is raised:
    ``AirflowSkipException`` when ``soft_fail`` is set, ``AirflowException`` otherwise.

    :param context: The Airflow task context.
    :param event: The trigger event; ``status`` and ``message`` are read, plus ``result`` on success.
    """
    succeeded = event["status"] == "success"
    if not succeeded:
        if self.soft_fail:
            raise AirflowSkipException(f"Sensor failed with the following message: {event['message']}.")
        raise AirflowException(f"Sensor failed with the following message: {event['message']}")

    self.log.info(event["message"])
    if self.callback is None:
        return True
    return self.callback(event["result"])
Loading
Loading