Improve ExternalTaskSensor Async Implementation (apache#36916)
pankajastro authored Jan 25, 2024
1 parent 390eacb commit e9a4bca
Showing 5 changed files with 315 additions and 70 deletions.
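For context: with `deferrable=True`, ExternalTaskSensor frees its worker slot and hands polling to the triggerer, and this commit replaces the trigger it defers to. A minimal usage sketch, assuming deferrable mode (DAG and task names are illustrative, not from this commit):

# Illustrative only; dag/task names are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.sensors.external_task import ExternalTaskSensor

with DAG("downstream_dag", start_date=datetime(2024, 1, 1), schedule="@daily"):
    wait = ExternalTaskSensor(
        task_id="wait_for_upstream",
        external_dag_id="upstream_dag",
        external_task_ids=["final_task"],
        deferrable=True,      # defer to the triggerer instead of occupying a worker
        poll_interval=30.0,   # forwarded to the trigger as poke_interval (per this diff)
    )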
102 changes: 36 additions & 66 deletions airflow/sensors/external_task.py
@@ -23,27 +23,24 @@
 from typing import TYPE_CHECKING, Any, Callable, Collection, Iterable
 
 import attr
-from sqlalchemy import func
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException, AirflowSkipException, RemovedInAirflow3Warning
 from airflow.models.baseoperatorlink import BaseOperatorLink
 from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
-from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
 from airflow.operators.empty import EmptyOperator
 from airflow.sensors.base import BaseSensorOperator
-from airflow.triggers.external_task import TaskStateTrigger
+from airflow.triggers.external_task import WorkflowTrigger
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.helpers import build_airflow_url_with_query
+from airflow.utils.sensor_helper import _get_count, _get_external_task_group_task_ids
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.timezone import utcnow
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Query, Session
+    from sqlalchemy.orm import Session
 
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.taskinstancekey import TaskInstanceKey
@@ -351,29 +348,32 @@ def execute(self, context: Context) -> None:
             super().execute(context)
         else:
             self.defer(
-                trigger=TaskStateTrigger(
-                    dag_id=self.external_dag_id,
-                    task_id=self.external_task_id,
-                    timeout=self.execution_timeout,
+                trigger=WorkflowTrigger(
+                    external_dag_id=self.external_dag_id,
+                    external_task_ids=self.external_task_ids,
                     execution_dates=self._get_dttm_filter(context),
-                    states=self.allowed_states,
-                    trigger_start_time=utcnow(),
-                    poll_interval=self.poll_interval,
+                    allowed_states=self.allowed_states,
+                    poke_interval=self.poll_interval,
+                    soft_fail=self.soft_fail,
                 ),
                 method_name="execute_complete",
            )
 
     def execute_complete(self, context, event=None):
         """Execute when the trigger fires - return immediately."""
         if event["status"] == "success":
-            self.log.info("External task %s has executed successfully.", self.external_task_id)
-            return None
-        elif event["status"] == "timeout":
-            raise AirflowException("Dag was not started within 1 minute, assuming fail.")
+            self.log.info("External tasks %s has executed successfully.", self.external_task_ids)
+        elif event["status"] == "skipped":
+            raise AirflowSkipException("External job has skipped skipping.")
         else:
-            raise AirflowException(
-                "Error occurred while trying to retrieve task status. Please, check the "
-                "name of executed task and Dag."
-            )
+            if self.soft_fail:
+                raise AirflowSkipException("External job has failed skipping.")
+            else:
+                raise AirflowException(
+                    "Error occurred while trying to retrieve task status. Please, check the "
+                    "name of executed task and Dag."
+                )
 
     def _check_for_existence(self, session) -> None:
         dag_to_wait = DagModel.get_current(self.external_dag_id, session)
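For readers new to deferrable operators: `self.defer(...)` raises `TaskDeferred`, the worker slot is released, and the triggerer process runs `WorkflowTrigger.run()` until it yields a `TriggerEvent`; the task then resumes in `execute_complete(context, event=...)` with the event payload. A simplified sketch of that handshake (illustrative only, not Airflow's actual triggerer loop):

# Simplified sketch of the defer/resume round trip (not Airflow internals).
async def fire_and_collect(trigger):
    async for event in trigger.run():  # WorkflowTrigger.run yields TriggerEvent({"status": ...})
        return event.payload           # e.g. {"status": "success"}

# The triggerer runs the coroutine (roughly asyncio.run(fire_and_collect(trigger))),
# persists the payload, and the resumed task receives it in
# execute_complete(context, event={"status": "success"}).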
@@ -412,55 +412,25 @@ def get_count(self, dttm_filter, session, states) -> int:
         :param states: task or dag states
         :return: count of record against the filters
         """
-        TI = TaskInstance
-        DR = DagRun
-        if not dttm_filter:
-            return 0
-
-        if self.external_task_ids:
-            count = (
-                self._count_query(TI, session, states, dttm_filter)
-                .filter(TI.task_id.in_(self.external_task_ids))
-                .scalar()
-            ) / len(self.external_task_ids)
-        elif self.external_task_group_id:
-            external_task_group_task_ids = self.get_external_task_group_task_ids(session, dttm_filter)
-            if not external_task_group_task_ids:
-                count = 0
-            else:
-                count = (
-                    self._count_query(TI, session, states, dttm_filter)
-                    .filter(tuple_in_condition((TI.task_id, TI.map_index), external_task_group_task_ids))
-                    .scalar()
-                ) / len(external_task_group_task_ids)
-        else:
-            count = self._count_query(DR, session, states, dttm_filter).scalar()
-        return count
-
-    def _count_query(self, model, session, states, dttm_filter) -> Query:
-        query = session.query(func.count()).filter(
-            model.dag_id == self.external_dag_id,
-            model.state.in_(states),
-            model.execution_date.in_(dttm_filter),
+        warnings.warn(
+            "This method is deprecated and will be removed in future.", DeprecationWarning, stacklevel=2
         )
-        return query
+        return _get_count(
+            dttm_filter,
+            self.external_task_ids,
+            self.external_task_group_id,
+            self.external_dag_id,
+            states,
+            session,
+        )
 
     def get_external_task_group_task_ids(self, session, dttm_filter):
-        refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(self.external_dag_id, session)
-        task_group = refreshed_dag_info.task_group_dict.get(self.external_task_group_id)
-
-        if task_group:
-            group_tasks = session.query(TaskInstance).filter(
-                TaskInstance.dag_id == self.external_dag_id,
-                TaskInstance.task_id.in_(task.task_id for task in task_group),
-                TaskInstance.execution_date.in_(dttm_filter),
-            )
-
-            return [(t.task_id, t.map_index) for t in group_tasks]
-
-        # returning default task_id as group_id itself, this will avoid any failure in case of
-        # 'check_existence=False' and will fail on timeout
-        return [(self.external_task_group_id, -1)]
+        warnings.warn(
+            "This method is deprecated and will be removed in future.", DeprecationWarning, stacklevel=2
+        )
+        return _get_external_task_group_task_ids(
+            dttm_filter, self.external_task_group_id, self.external_dag_id, session
+        )
 
     def _handle_execution_date_fn(self, context) -> Any:
         """
[…]
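`get_count` and `get_external_task_group_task_ids` now delegate to `airflow.utils.sensor_helper` and warn on use; `stacklevel=2` attributes the `DeprecationWarning` to the caller's line rather than to the sensor internals. A standalone illustration of that behavior (plain Python, names are hypothetical):

import warnings

def deprecated_method():
    # stacklevel=2 points the warning at whoever called deprecated_method()
    warnings.warn("This method is deprecated and will be removed in future.",
                  DeprecationWarning, stacklevel=2)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    deprecated_method()
    assert caught[0].category is DeprecationWarning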
99 changes: 99 additions & 0 deletions airflow/triggers/external_task.py
@@ -18,12 +18,14 @@
 
 import asyncio
 import typing
+from typing import Any
 
 from asgiref.sync import sync_to_async
 from sqlalchemy import func
 
 from airflow.models import DagRun, TaskInstance
 from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.sensor_helper import _get_count
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.timezone import utcnow
@@ -36,6 +38,103 @@
    from airflow.utils.state import DagRunState


class WorkflowTrigger(BaseTrigger):
    """
    A trigger to monitor tasks, task group and dag execution in Apache Airflow.

    :param external_dag_id: The ID of the external DAG.
    :param execution_dates: A list of execution dates for the external DAG.
    :param external_task_ids: A collection of external task IDs to wait for.
    :param external_task_group_id: The ID of the external task group to wait for.
    :param failed_states: States considered as failed for external tasks.
    :param skipped_states: States considered as skipped for external tasks.
    :param allowed_states: States considered as successful for external tasks.
    :param poke_interval: The interval (in seconds) for poking the external tasks.
    :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
    """

    def __init__(
        self,
        external_dag_id: str,
        execution_dates: list,
        external_task_ids: typing.Collection[str] | None = None,
        external_task_group_id: str | None = None,
        failed_states: typing.Iterable[str] | None = None,
        skipped_states: typing.Iterable[str] | None = None,
        allowed_states: typing.Iterable[str] | None = None,
        poke_interval: float = 2.0,
        soft_fail: bool = False,
        **kwargs,
    ):
        self.external_dag_id = external_dag_id
        self.external_task_ids = external_task_ids
        self.external_task_group_id = external_task_group_id
        self.failed_states = failed_states
        self.skipped_states = skipped_states
        self.allowed_states = allowed_states
        self.execution_dates = execution_dates
        self.poke_interval = poke_interval
        self.soft_fail = soft_fail
        super().__init__(**kwargs)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger param and module path."""
        return (
            "airflow.triggers.external_task.WorkflowTrigger",
            {
                "external_dag_id": self.external_dag_id,
                "external_task_ids": self.external_task_ids,
                "external_task_group_id": self.external_task_group_id,
                "failed_states": self.failed_states,
                "skipped_states": self.skipped_states,
                "allowed_states": self.allowed_states,
                "execution_dates": self.execution_dates,
                "poke_interval": self.poke_interval,
                "soft_fail": self.soft_fail,
            },
        )

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """Check periodically tasks, task group or dag status."""
        while True:
            if self.failed_states:
                failed_count = _get_count(
                    self.execution_dates,
                    self.external_task_ids,
                    self.external_task_group_id,
                    self.external_dag_id,
                    self.failed_states,
                )
                if failed_count > 0:
                    yield TriggerEvent({"status": "failed"})
                    return
                else:
                    yield TriggerEvent({"status": "success"})
                    return
            if self.skipped_states:
                skipped_count = _get_count(
                    self.execution_dates,
                    self.external_task_ids,
                    self.external_task_group_id,
                    self.external_dag_id,
                    self.skipped_states,
                )
                if skipped_count > 0:
                    yield TriggerEvent({"status": "skipped"})
            allowed_count = _get_count(
                self.execution_dates,
                self.external_task_ids,
                self.external_task_group_id,
                self.external_dag_id,
                self.allowed_states,
            )
            if allowed_count == len(self.execution_dates):
                yield TriggerEvent({"status": "success"})
                return
            self.log.info("Sleeping for %s seconds", self.poke_interval)
            await asyncio.sleep(self.poke_interval)
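`serialize` returns the trigger's classpath plus its constructor kwargs; that pair is persisted and used by the triggerer to rebuild an equivalent trigger in another process. A hedged round-trip sketch (argument values are hypothetical):

# Hypothetical round trip; the triggerer performs this re-instantiation for real.
from datetime import datetime

from airflow.triggers.external_task import WorkflowTrigger

classpath, kwargs = WorkflowTrigger(
    external_dag_id="upstream_dag",
    execution_dates=[datetime(2024, 1, 25)],
    external_task_ids=["final_task"],
).serialize()
assert classpath == "airflow.triggers.external_task.WorkflowTrigger"
rebuilt = WorkflowTrigger(**kwargs)  # an equivalent trigger, rebuilt from plain kwargs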


class TaskStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a task in a different DAG to complete for a specific logical date.
[…]
123 changes: 123 additions & 0 deletions airflow/utils/sensor_helper.py
@@ -0,0 +1,123 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, cast

from sqlalchemy import func, select

from airflow.models import DagBag, DagRun, TaskInstance
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition

if TYPE_CHECKING:
    from sqlalchemy.orm import Query, Session


@provide_session
def _get_count(
    dttm_filter,
    external_task_ids,
    external_task_group_id,
    external_dag_id,
    states,
    session: Session = NEW_SESSION,
) -> int:
    """
    Get the count of records against dttm filter and states.

    :param dttm_filter: date time filter for execution date
    :param external_task_ids: The list of task_ids
    :param external_task_group_id: The ID of the external task group
    :param external_dag_id: The ID of the external DAG.
    :param states: task or dag states
    :param session: airflow session object
    """
    TI = TaskInstance
    DR = DagRun
    if not dttm_filter:
        return 0

    if external_task_ids:
        count = (
            session.scalar(
                _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
                    TI.task_id.in_(external_task_ids)
                )
            )
        ) / len(external_task_ids)
    elif external_task_group_id:
        external_task_group_task_ids = _get_external_task_group_task_ids(
            dttm_filter, external_task_group_id, external_dag_id, session
        )
        if not external_task_group_task_ids:
            count = 0
        else:
            count = (
                session.scalar(
                    _count_query(TI, states, dttm_filter, external_dag_id, session).filter(
                        tuple_in_condition((TI.task_id, TI.map_index), external_task_group_task_ids)
                    )
                )
            ) / len(external_task_group_task_ids)
    else:
        count = session.scalar(_count_query(DR, states, dttm_filter, external_dag_id, session))
    return cast(int, count)


def _count_query(model, states, dttm_filter, external_dag_id, session: Session) -> Query:
    """
    Build the count query against dttm filter and states.

    :param model: The SQLAlchemy model representing the relevant table.
    :param states: task or dag states
    :param dttm_filter: date time filter for execution date
    :param external_dag_id: The ID of the external DAG.
    :param session: airflow session object
    """
    query = select(func.count()).filter(
        model.dag_id == external_dag_id, model.state.in_(states), model.execution_date.in_(dttm_filter)
    )
    return query
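# `select(func.count())` builds a SQLAlchemy 2.0-style Select statement; callers
# narrow it further (e.g. `.filter(TI.task_id.in_(...))`) and execute it with
# `session.scalar(...)`, which returns the single COUNT(*) value.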


def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, external_dag_id, session):
    """
    Get the task ids of the external task group matching the dttm filter.

    :param dttm_filter: date time filter for execution date
    :param external_task_group_id: The ID of the external task group
    :param external_dag_id: The ID of the external DAG.
    :param session: airflow session object
    """
    refreshed_dag_info = DagBag(read_dags_from_db=True).get_dag(external_dag_id, session)
    task_group = refreshed_dag_info.task_group_dict.get(external_task_group_id)

    if task_group:
        group_tasks = session.scalars(
            select(TaskInstance).filter(
                TaskInstance.dag_id == external_dag_id,
                TaskInstance.task_id.in_(task.task_id for task in task_group),
                TaskInstance.execution_date.in_(dttm_filter),
            )
        )

        return [(t.task_id, t.map_index) for t in group_tasks]

    # returning default task_id as group_id itself, this will avoid any failure in case of
    # 'check_existence=False' and will fail on timeout
    return [(external_task_group_id, -1)]
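The division by `len(external_task_ids)` (or by the number of task-group tuples) normalizes the raw task-instance count so it is comparable with the number of execution dates: both the sensor and `WorkflowTrigger.run` treat `count == len(execution_dates)` as success. A self-contained sketch of the arithmetic, with made-up numbers:

# Made-up numbers; the real helper computes matching_instances with a SQL COUNT.
external_task_ids = ["task_a", "task_b"]
dttm_filter = ["2024-01-23", "2024-01-24", "2024-01-25"]  # three execution dates

matching_instances = 6  # both tasks reached an allowed state on all three dates
count = matching_instances / len(external_task_ids)

assert count == len(dttm_filter)  # the success condition checked by the trigger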
[…]