Allow Task Group Ids to be passed as branches in BranchMixIn (apache#38883)

* Allow `Task Group Id`s to be passed as branches in BranchMixIn
boraberke authored Jun 3, 2024
1 parent 1436ca4 commit bb06c56
Showing 7 changed files with 113 additions and 29 deletions.
5 changes: 3 additions & 2 deletions airflow/models/skipmixin.py
@@ -175,13 +175,12 @@ def skip_all_except(
branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or
newly added tasks should be skipped when they are cleared.
"""
self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, str):
branch_task_id_set = {branch_task_ids}
elif isinstance(branch_task_ids, Iterable):
branch_task_id_set = set(branch_task_ids)
invalid_task_ids_type = {
(bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
(bti, type(bti).__name__) for bti in branch_task_id_set if not isinstance(bti, str)
}
if invalid_task_ids_type:
raise AirflowException(
@@ -196,6 +195,8 @@
f"but got {type(branch_task_ids).__name__!r}."
)

self.log.info("Following branch %s", branch_task_id_set)

dag_run = ti.get_dagrun()
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)
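For reference, a minimal sketch of the input shapes ``skip_all_except`` accepts after this hunk. It is illustrative only: ``op`` stands for any SkipMixin-based operator and ``ti`` for its TaskInstance, and the exception texts are the ones referenced above and exercised in the updated tests below.

# Illustrative sketch only; `op` and `ti` are assumed to already exist.
op.skip_all_except(ti, "branch_1")                # a single task id
op.skip_all_except(ti, ["branch_1", "branch_2"])  # an iterable of task ids
op.skip_all_except(ti, None)                      # no branch to follow
op.skip_all_except(ti, [True])  # raises AirflowException: ... Invalid tasks found: {(True, 'bool')}
op.skip_all_except(ti, 5)       # raises AirflowException: ... must be either None, a task ID, or an Iterable of IDs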
39 changes: 35 additions & 4 deletions airflow/operators/branch.py
@@ -25,6 +25,8 @@
from airflow.models.skipmixin import SkipMixin

if TYPE_CHECKING:
from airflow.models import TaskInstance
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.utils.context import Context


@@ -34,20 +36,49 @@ class BranchMixIn(SkipMixin):
def do_branch(self, context: Context, branches_to_execute: str | Iterable[str]) -> str | Iterable[str]:
"""Implement the handling of branching including logging."""
self.log.info("Branch into %s", branches_to_execute)
self.skip_all_except(context["ti"], branches_to_execute)
branch_task_ids = self._expand_task_group_roots(context["ti"], branches_to_execute)
self.skip_all_except(context["ti"], branch_task_ids)
return branches_to_execute

def _expand_task_group_roots(
self, ti: TaskInstance | TaskInstancePydantic, branches_to_execute: str | Iterable[str]
) -> Iterable[str]:
"""Expand any task group into its root task ids."""
if TYPE_CHECKING:
assert ti.task

task = ti.task
dag = task.dag
if TYPE_CHECKING:
assert dag

if branches_to_execute is None:
return
elif isinstance(branches_to_execute, str) or not isinstance(branches_to_execute, Iterable):
branches_to_execute = [branches_to_execute]

for branch in branches_to_execute:
if branch in dag.task_group_dict:
tg = dag.task_group_dict[branch]
root_ids = [root.task_id for root in tg.roots]
self.log.info("Expanding task group %s into %s", tg.group_id, root_ids)
yield from root_ids
else:
yield branch


class BaseBranchOperator(BaseOperator, BranchMixIn):
"""
A base class for creating operators with branching functionality, similar to BranchPythonOperator.
Users should create a subclass from this operator and implement the function
`choose_branch(self, context)`. This should run whatever business logic
is needed to determine the branch, and return either the task_id for
a single task (as a str) or a list of task_ids.
is needed to determine the branch, and return one of the following:
- A single task_id (as a str)
- A single task_group_id (as a str)
- A list containing a combination of task_ids and task_group_ids
The operator will continue with the returned task_id(s), and all other
The operator will continue with the returned task_id(s) and/or task_group_id(s), and all other
tasks directly downstream of this operator will be skipped.
"""

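A minimal usage sketch of the behaviour described in the docstring above, mirroring the new test added below: a subclass whose ``choose_branch`` returns a task_group_id, so the roots of that group run and the sibling task is skipped. The DAG id, task ids, and the ``PickGroup`` class are invented for illustration.

# Illustrative sketch; not part of the commit. Ids below are invented.
import pendulum

from airflow.models.dag import DAG
from airflow.operators.branch import BaseBranchOperator
from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import TaskGroup


class PickGroup(BaseBranchOperator):
    def choose_branch(self, context):
        # A task_group_id is now a valid branch target alongside plain task_ids.
        return "branch_group"


with DAG(
    dag_id="branch_into_task_group_example",
    start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    make_choice = PickGroup(task_id="make_choice")
    other = EmptyOperator(task_id="other_branch")
    with TaskGroup("branch_group") as branch_group:
        EmptyOperator(task_id="task_1")
        EmptyOperator(task_id="task_2")

    # "other_branch" is skipped; "branch_group.task_1" and "branch_group.task_2" run.
    make_choice >> [other, branch_group]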
8 changes: 4 additions & 4 deletions airflow/operators/datetime.py
@@ -37,10 +37,10 @@ class BranchDateTimeOperator(BaseBranchOperator):
True branch will be returned when ``datetime.datetime.now()`` falls below
``target_upper`` and above ``target_lower``.
:param follow_task_ids_if_true: task id or task ids to follow if
``datetime.datetime.now()`` falls above target_lower and below ``target_upper``.
:param follow_task_ids_if_false: task id or task ids to follow if
``datetime.datetime.now()`` falls below target_lower or above ``target_upper``.
:param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
to follow if ``datetime.datetime.now()`` falls above target_lower and below target_upper.
:param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
to follow if ``datetime.datetime.now()`` falls below target_lower or above target_upper.
:param target_lower: target lower bound.
:param target_upper: target upper bound.
:param use_task_logical_date: If ``True``, uses task's logical date to compare with targets.
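A hedged usage sketch for the updated parameters above: ids are invented, and the operator is assumed to sit inside a ``with DAG(...):`` block.

# Illustrative sketch; not part of the commit.
import datetime

from airflow.operators.datetime import BranchDateTimeOperator

datetime_branch = BranchDateTimeOperator(
    task_id="datetime_branch",
    follow_task_ids_if_true="business_hours_group",   # a task_group_id is now accepted
    follow_task_ids_if_false=["after_hours_cleanup"],  # plain task_ids still work
    target_lower=datetime.time(9, 0),
    target_upper=datetime.time(17, 0),
)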
26 changes: 14 additions & 12 deletions airflow/operators/python.py
@@ -261,12 +261,13 @@ class BranchPythonOperator(PythonOperator, BranchMixIn):
A workflow can "branch" or follow a path after the execution of this task.
It derives the PythonOperator and expects a Python function that returns
a single task_id or list of task_ids to follow. The task_id(s) returned
should point to a task directly downstream from {self}. All other "branches"
or directly downstream tasks are marked with a state of ``skipped`` so that
these paths can't move forward. The ``skipped`` states are propagated
downstream to allow for the DAG state to fill up and the DAG run's state
to be inferred.
a single task_id, a single task_group_id, or a list of task_ids and/or
task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
should point to a task or task group directly downstream from {self}. All
other "branches" or directly downstream tasks are marked with a state of
``skipped`` so that these paths can't move forward. The ``skipped`` states
are propagated downstream to allow for the DAG state to fill up and
the DAG run's state to be inferred.
"""

def execute(self, context: Context) -> Any:
@@ -861,12 +862,13 @@ class BranchPythonVirtualenvOperator(PythonVirtualenvOperator, BranchMixIn):
A workflow can "branch" or follow a path after the execution of this task in a virtual environment.
It derives the PythonVirtualenvOperator and expects a Python function that returns
a single task_id or list of task_ids to follow. The task_id(s) returned
should point to a task directly downstream from {self}. All other "branches"
or directly downstream tasks are marked with a state of ``skipped`` so that
these paths can't move forward. The ``skipped`` states are propagated
downstream to allow for the DAG state to fill up and the DAG run's state
to be inferred.
a single task_id, a single task_group_id, or a list of task_ids and/or
task_group_ids to follow. The task_id(s) and/or task_group_id(s) returned
should point to a task or task group directly downstream from {self}. All
other "branches" or directly downstream tasks are marked with a state of
``skipped`` so that these paths can't move forward. The ``skipped`` states
are propagated downstream to allow for the DAG state to fill up and
the DAG run's state to be inferred.
.. seealso::
For more information on how to use this operator, take a look at the guide:
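A hedged usage sketch for the updated docstrings in this file: ids are invented, the operator is assumed to sit inside a ``with DAG(...):`` block, and the same return values apply to BranchPythonVirtualenvOperator.

# Illustrative sketch; not part of the commit.
from airflow.operators.python import BranchPythonOperator


def pick_branches():
    # May return a single task_id, a single task_group_id, or a mixed list of both.
    return ["standalone_task", "cleanup_group"]


choose = BranchPythonOperator(task_id="choose", python_callable=pick_branches)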
6 changes: 4 additions & 2 deletions airflow/operators/weekday.py
@@ -73,8 +73,10 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
# add downstream dependencies as you would do with any branch operator
weekend_check >> [workday, weekend]
:param follow_task_ids_if_true: task id or task ids to follow if criteria met
:param follow_task_ids_if_false: task id or task ids to follow if criteria does not met
:param follow_task_ids_if_true: task_id, task_group_id, or a list of task_ids and/or task_group_ids
to follow if criteria met.
:param follow_task_ids_if_false: task_id, task_group_id, or a list of task_ids and/or task_group_ids
to follow if criteria not met.
:param week_day: Day of the week to check (full name). Optionally, a set
of days can also be provided using a set. Example values:
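A hedged usage sketch for the updated parameters above: ids are invented, and the operator is assumed to sit inside a ``with DAG(...):`` block.

# Illustrative sketch; not part of the commit.
from airflow.operators.weekday import BranchDayOfWeekOperator
from airflow.utils.weekday import WeekDay

weekend_check = BranchDayOfWeekOperator(
    task_id="weekend_check",
    week_day={WeekDay.SATURDAY, WeekDay.SUNDAY},
    follow_task_ids_if_true="weekend_group",   # a task_group_id is now a valid target
    follow_task_ids_if_false="workday_task",
)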
42 changes: 42 additions & 0 deletions tests/operators/test_branch_operator.py
@@ -29,6 +29,7 @@
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup
from airflow.utils.types import DagRunType

pytestmark = pytest.mark.db_test
@@ -47,6 +48,11 @@ def choose_branch(self, context):
return ["branch_1", "branch_2"]


class ChooseBranchThree(BaseBranchOperator):
def choose_branch(self, context):
return ["branch_3"]


class TestBranchOperator:
@classmethod
def setup_class(cls):
@@ -191,3 +197,39 @@ def test_xcom_push(self):
for ti in tis:
if ti.task_id == "make_choice":
assert ti.xcom_pull(task_ids="make_choice") == "branch_1"

def test_with_dag_run_task_groups(self):
self.branch_op = ChooseBranchThree(task_id="make_choice", dag=self.dag)
self.branch_3 = TaskGroup("branch_3", dag=self.dag)
_ = EmptyOperator(task_id="task_1", dag=self.dag, task_group=self.branch_3)
_ = EmptyOperator(task_id="task_2", dag=self.dag, task_group=self.branch_3)

self.branch_1.set_upstream(self.branch_op)
self.branch_2.set_upstream(self.branch_op)
self.branch_3.set_upstream(self.branch_op)

self.dag.clear()

dagrun = self.dag.create_dagrun(
run_type=DagRunType.MANUAL,
start_date=timezone.utcnow(),
execution_date=DEFAULT_DATE,
state=State.RUNNING,
)

self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

tis = dagrun.get_task_instances()
for ti in tis:
if ti.task_id == "make_choice":
assert ti.state == State.SUCCESS
elif ti.task_id == "branch_1":
assert ti.state == State.SKIPPED
elif ti.task_id == "branch_2":
assert ti.state == State.SKIPPED
elif ti.task_id == "branch_3.task_1":
assert ti.state == State.NONE
elif ti.task_id == "branch_3.task_2":
assert ti.state == State.NONE
else:
raise Exception
16 changes: 11 additions & 5 deletions tests/operators/test_python.py
@@ -463,7 +463,10 @@ def f():
return 5

ti = self.create_ti(f)
with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
with pytest.raises(
AirflowException,
match="'branch_task_ids' expected all task IDs are strings.",
):
ti.run()

def test_raise_exception_on_invalid_task_id(self):
@@ -1440,14 +1443,14 @@ def f(a, b, c=False, d=False):
else:
raise RuntimeError

with pytest.raises(AirflowException, match="but got 'bool'"):
with pytest.raises(AirflowException, match=r"Invalid tasks found: {\((True|False), 'bool'\)}"):
self.run_as_task(f, op_args=[0, 1], op_kwargs={"c": True})

def test_return_false(self):
def f():
return False

with pytest.raises(AirflowException, match="but got 'bool'"):
with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
self.run_as_task(f)

def test_context(self):
@@ -1468,7 +1471,7 @@ def test_with_no_caching(self):
def f():
return False

with pytest.raises(AirflowException, match="but got 'bool'"):
with pytest.raises(AirflowException, match=r"Invalid tasks found: {\(False, 'bool'\)}."):
self.run_as_task(f, do_not_use_caching=True)

def test_with_dag_run(self):
@@ -1581,7 +1584,10 @@ def f():
return 5

ti = self.create_ti(f)
with pytest.raises(AirflowException, match="must be either None, a task ID, or an Iterable of IDs"):
with pytest.raises(
AirflowException,
match="'branch_task_ids' expected all task IDs are strings.",
):
ti.run()

def test_raise_exception_on_invalid_task_id(self):
