diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 1ead0e7cec88b..608824982cfa3 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -49,6 +49,7 @@ def _initialize_map() -> dict[str, Callable]: from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning from airflow.models.serialized_dag import SerializedDagModel + from airflow.models.skipmixin import SkipMixin from airflow.models.taskinstance import ( TaskInstance, _add_log, @@ -110,6 +111,8 @@ def _initialize_map() -> dict[str, Callable]: DagRun.fetch_task_instance, DagRun._get_log_template, SerializedDagModel.get_serialized_dag, + SkipMixin._skip, + SkipMixin._skip_all_except, TaskInstance._check_and_change_state_before_execution, TaskInstance.get_task_instance, TaskInstance._get_dagrun, diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 9b84bb9b3cce5..6c3d0715b9fbd 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -651,6 +651,7 @@ def get_task_instance( ) @staticmethod + @internal_api_call @provide_session def fetch_task_instance( dag_id: str, diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index 3c89deda1245f..1ed56a43bff2a 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -18,15 +18,17 @@ from __future__ import annotations import warnings +from types import GeneratorType from typing import TYPE_CHECKING, Iterable, Sequence from sqlalchemy import select, update +from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import AirflowException, RemovedInAirflow3Warning from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from 
airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import TaskInstanceState @@ -60,8 +62,8 @@ def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]: class SkipMixin(LoggingMixin): """A Mixin to skip Tasks Instances.""" + @staticmethod def _set_state_to_skipped( - self, dag_run: DagRun | DagRunPydantic, tasks: Sequence[str] | Sequence[tuple[str, int]], session: Session, @@ -93,12 +95,28 @@ def _set_state_to_skipped( .execution_options(synchronize_session=False) ) - @provide_session def skip( self, dag_run: DagRun | DagRunPydantic, execution_date: DateTime, tasks: Iterable[DAGNode], + map_index: int = -1, + ): + """Facade for compatibility for call to internal API.""" + # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. + task_id: str | None = getattr(self, "task_id", None) + SkipMixin._skip( + dag_run=dag_run, task_id=task_id, execution_date=execution_date, tasks=tasks, map_index=map_index + ) + + @staticmethod + @internal_api_call + @provide_session + def _skip( + dag_run: DagRun | DagRunPydantic, + task_id: str | None, + execution_date: DateTime, + tasks: Iterable[DAGNode], session: Session = NEW_SESSION, map_index: int = -1, ): @@ -143,11 +161,9 @@ def skip( raise ValueError("dag_run is required") task_ids_list = [d.task_id for d in task_list] - self._set_state_to_skipped(dag_run, task_ids_list, session) + SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session) session.commit() - # SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available. 
- task_id: str | None = getattr(self, "task_id", None) if task_id is not None: from airflow.models.xcom import XCom @@ -165,6 +181,21 @@ def skip_all_except( self, ti: TaskInstance | TaskInstancePydantic, branch_task_ids: None | str | Iterable[str], + ): + """Facade for compatibility for call to internal API.""" + # Ensure we don't serialize a generator object + if branch_task_ids and isinstance(branch_task_ids, GeneratorType): + branch_task_ids = list(branch_task_ids) + SkipMixin._skip_all_except(ti=ti, branch_task_ids=branch_task_ids) + + @classmethod + @internal_api_call + @provide_session + def _skip_all_except( + cls, + ti: TaskInstance | TaskInstancePydantic, + branch_task_ids: None | str | Iterable[str], + session: Session = NEW_SESSION, ): """ Implement the logic for a branching operator. @@ -175,6 +206,7 @@ def skip_all_except( branch_task_ids is stored to XCom so that NotPreviouslySkippedDep knows skipped tasks or newly added tasks should be skipped when they are cleared. """ + log = cls().log # Note: need to catch logger from instance, static logger breaks pytest if isinstance(branch_task_ids, str): branch_task_id_set = {branch_task_ids} elif isinstance(branch_task_ids, Iterable): @@ -195,20 +227,15 @@ def skip_all_except( f"but got {type(branch_task_ids).__name__!r}." 
) - self.log.info("Following branch %s", branch_task_id_set) + log.info("Following branch %s", branch_task_id_set) - dag_run = ti.get_dagrun() + dag_run = ti.get_dagrun(session=session) if TYPE_CHECKING: assert isinstance(dag_run, DagRun) assert ti.task - # TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to - # pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition - # does not attempt to serialize the field from/to ORM task = ti.task - dag = task.dag - if TYPE_CHECKING: - assert dag + dag = TaskInstance.ensure_dag(ti, session=session) valid_task_ids = set(dag.task_ids) invalid_task_ids = branch_task_id_set - valid_task_ids @@ -239,15 +266,17 @@ def skip_all_except( skip_tasks = [ (t.task_id, downstream_ti.map_index) for t in downstream_tasks - if (downstream_ti := dag_run.get_task_instance(t.task_id, map_index=ti.map_index)) + if ( + downstream_ti := dag_run.get_task_instance( + t.task_id, map_index=ti.map_index, session=session + ) + ) and t.task_id not in branch_task_id_set ] follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set] - self.log.info("Skipping tasks %s", skip_tasks) - with create_session() as session: - self._set_state_to_skipped(dag_run, skip_tasks, session=session) - # For some reason, session.commit() needs to happen before xcom_push. - # Otherwise the session is not committed. 
- session.commit() - ti.xcom_push(key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}) + log.info("Skipping tasks %s", skip_tasks) + SkipMixin._set_state_to_skipped(dag_run, skip_tasks, session=session) + ti.xcom_push( + key=XCOM_SKIPMIXIN_KEY, value={XCOM_SKIPMIXIN_FOLLOWED: follow_task_ids}, session=session + ) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 80b3eedbc8433..1dacbe7525ded 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2633,6 +2633,22 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun: return dr + @classmethod + @provide_session + def ensure_dag( + cls, task_instance: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION + ) -> DAG: + """Ensure that task has a dag object associated, might have been removed by serialization.""" + if TYPE_CHECKING: + assert task_instance.task + if task_instance.task.dag is None or task_instance.task.dag is ATTRIBUTE_REMOVED: + task_instance.task.dag = DagBag(read_dags_from_db=True).get_dag( + dag_id=task_instance.dag_id, session=session + ) + if TYPE_CHECKING: + assert task_instance.task.dag + return task_instance.task.dag + @classmethod @internal_api_call @provide_session diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 2fd5fb0fe6e09..465d15130f4de 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -65,7 +65,7 @@ def test_skip(self, mock_now, dag_maker): execution_date=now, state=State.FAILED, ) - SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks, session=session) + SkipMixin().skip(dag_run=dag_run, execution_date=now, tasks=tasks) session.query(TI).filter( TI.dag_id == "dag", @@ -91,7 +91,7 @@ def test_skip_none_dagrun(self, mock_now, dag_maker): RemovedInAirflow3Warning, match=r"Passing an execution_date to `skip\(\)` is deprecated in favour of passing a dag_run", ): - SkipMixin().skip(dag_run=None, 
execution_date=now, tasks=tasks, session=session) + SkipMixin().skip(dag_run=None, execution_date=now, tasks=tasks) session.query(TI).filter( TI.dag_id == "dag", @@ -103,7 +103,7 @@ def test_skip_none_dagrun(self, mock_now, dag_maker): def test_skip_none_tasks(self): session = Mock() - SkipMixin().skip(dag_run=None, execution_date=None, tasks=[], session=session) + SkipMixin().skip(dag_run=None, execution_date=None, tasks=[]) assert not session.query.called assert not session.commit.called