From aa7a3b24fd0bc267ffca67c740fbe09f892ae39f Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Wed, 20 Nov 2024 19:46:20 +0545 Subject: [PATCH] Remove logical_date from APIs and Functions, use run_id instead (#42404) Co-authored-by: Tzu-ping Chung --- airflow/api/common/mark_tasks.py | 57 ------- airflow/api/common/trigger_dag.py | 4 +- .../endpoints/task_instance_endpoint.py | 16 +- .../schemas/task_instance_schema.py | 10 +- airflow/cli/commands/task_command.py | 70 ++++++--- airflow/exceptions.py | 11 +- airflow/models/dag.py | 145 +++++++++--------- airflow/models/dagrun.py | 22 +-- airflow/www/views.py | 36 +++-- newsfragments/42404.significant.rst | 6 + .../endpoints/test_task_instance_endpoint.py | 68 ++------ .../schemas/test_task_instance_schema.py | 6 +- tests/cli/commands/test_task_command.py | 26 +++- tests/models/test_dag.py | 12 +- tests/models/test_dagrun.py | 7 +- tests/operators/test_trigger_dagrun.py | 19 +-- tests/sensors/test_external_task_sensor.py | 26 +++- 17 files changed, 231 insertions(+), 310 deletions(-) create mode 100644 newsfragments/42404.significant.rst diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index a170e6901a503..b57d25498d267 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -27,7 +27,6 @@ from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone -from airflow.utils.helpers import exactly_one from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -87,7 +86,6 @@ def set_state( *, tasks: Collection[Operator | tuple[Operator, int]], run_id: str | None = None, - logical_date: datetime | None = None, upstream: bool = False, downstream: bool = False, future: bool = False, @@ -107,7 +105,6 @@ def set_state( :param tasks: the iterable of tasks or (task, map_index) tuples from which to work. 
``task.dag`` needs to be set :param run_id: the run_id of the dagrun to start looking from - :param logical_date: the logical date from which to start looking (deprecated) :param upstream: Mark all parents (upstream tasks) :param downstream: Mark all siblings (downstream tasks) of task_id :param future: Mark all future tasks on the interval of the dag up until @@ -121,21 +118,12 @@ def set_state( if not tasks: return [] - if not exactly_one(logical_date, run_id): - raise ValueError("Exactly one of dag_run_id and logical_date must be set") - - if logical_date and not timezone.is_localized(logical_date): - raise ValueError(f"Received non-localized date {logical_date}") - task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks} if len(task_dags) > 1: raise ValueError(f"Received tasks from multiple DAGs: {task_dags}") dag = next(iter(task_dags)) if dag is None: raise ValueError("Received tasks with no DAG") - - if logical_date: - run_id = dag.get_dagrun(logical_date=logical_date, session=session).run_id if not run_id: raise ValueError("Received tasks with no run_id") @@ -279,7 +267,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA def set_dag_run_state_to_success( *, dag: DAG, - logical_date: datetime | None = None, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -290,7 +277,6 @@ def set_dag_run_state_to_success( Set for a specific logical date and its task instances to success. :param dag: the DAG of which to alter state - :param logical_date: the logical date from which to start looking(deprecated) :param run_id: the run_id to start looking from :param commit: commit DAG and tasks to be altered to the database :param session: database session @@ -298,19 +284,8 @@ def set_dag_run_state_to_success( otherwise list of tasks that will be updated :raises: ValueError if dag or logical_date is invalid """ - if not exactly_one(logical_date, run_id): - return [] - if not dag: return [] - - if logical_date: - if not timezone.is_localized(logical_date): - raise ValueError(f"Received non-localized date {logical_date}") - dag_run = dag.get_dagrun(logical_date=logical_date) - if not dag_run: - raise ValueError(f"DagRun with logical_date: {logical_date} not found") - run_id = dag_run.run_id if not run_id: raise ValueError(f"Invalid dag_run_id: {run_id}") # Mark the dag run to success. @@ -333,7 +308,6 @@ def set_dag_run_state_to_success( def set_dag_run_state_to_failed( *, dag: DAG, - logical_date: datetime | None = None, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -344,27 +318,14 @@ def set_dag_run_state_to_failed( Set for a specific logical date and its task instances to failed. 
:param dag: the DAG of which to alter state - :param logical_date: the logical date from which to start looking(deprecated) :param run_id: the DAG run_id to start looking from :param commit: commit DAG and tasks to be altered to the database :param session: database session :return: If commit is true, list of tasks that have been updated, otherwise list of tasks that will be updated - :raises: AssertionError if dag or logical_date is invalid """ - if not exactly_one(logical_date, run_id): - return [] if not dag: return [] - - if logical_date: - if not timezone.is_localized(logical_date): - raise ValueError(f"Received non-localized date {logical_date}") - dag_run = dag.get_dagrun(logical_date=logical_date) - if not dag_run: - raise ValueError(f"DagRun with logical_date: {logical_date} not found") - run_id = dag_run.run_id - if not run_id: raise ValueError(f"Invalid dag_run_id: {run_id}") @@ -429,7 +390,6 @@ def __set_dag_run_state_to_running_or_queued( *, new_state: DagRunState, dag: DAG, - logical_date: datetime | None = None, run_id: str | None = None, commit: bool = False, session: SASession, @@ -438,7 +398,6 @@ def __set_dag_run_state_to_running_or_queued( Set the dag run for a specific logical date to running. :param dag: the DAG of which to alter state - :param logical_date: the logical date from which to start looking :param run_id: the id of the DagRun :param commit: commit DAG and tasks to be altered to the database :param session: database session @@ -446,20 +405,8 @@ def __set_dag_run_state_to_running_or_queued( otherwise list of tasks that will be updated """ res: list[TaskInstance] = [] - - if not exactly_one(logical_date, run_id): - return res - if not dag: return res - - if logical_date: - if not timezone.is_localized(logical_date): - raise ValueError(f"Received non-localized date {logical_date}") - dag_run = dag.get_dagrun(logical_date=logical_date) - if not dag_run: - raise ValueError(f"DagRun with logical_date: {logical_date} not found") - run_id = dag_run.run_id if not run_id: raise ValueError(f"DagRun with run_id: {run_id} not found") # Mark the dag run to running. 
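The hunks above strip every ``logical_date`` branch from the mark-tasks helpers, so ``run_id`` becomes the only accepted lookup key. A minimal sketch of the resulting call shape (the ``dag`` object and the ``manual__2024-01-01T00:00:00+00:00`` run ID are illustrative assumptions, not part of the patch)::

    from airflow.api.common.mark_tasks import set_dag_run_state_to_success

    # After this patch the helper validates run_id only; the timezone-aware
    # logical_date checks removed above are gone along with the parameter.
    updated_tis = set_dag_run_state_to_success(
        dag=dag,  # assumed: the DAG whose run should be marked
        run_id="manual__2024-01-01T00:00:00+00:00",  # assumed run ID
        commit=True,
    )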
@@ -474,7 +421,6 @@ def __set_dag_run_state_to_running_or_queued( def set_dag_run_state_to_running( *, dag: DAG, - logical_date: datetime | None = None, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -487,7 +433,6 @@ def set_dag_run_state_to_running( return __set_dag_run_state_to_running_or_queued( new_state=DagRunState.RUNNING, dag=dag, - logical_date=logical_date, run_id=run_id, commit=commit, session=session, @@ -498,7 +443,6 @@ def set_dag_run_state_to_running( def set_dag_run_state_to_queued( *, dag: DAG, - logical_date: datetime | None = None, run_id: str | None = None, commit: bool = False, session: SASession = NEW_SESSION, @@ -511,7 +455,6 @@ def set_dag_run_state_to_queued( return __set_dag_run_state_to_running_or_queued( new_state=DagRunState.QUEUED, dag=dag, - logical_date=logical_date, run_id=run_id, commit=commit, session=session, diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py index 4a94f990191c4..6891cc1df7830 100644 --- a/airflow/api/common/trigger_dag.py +++ b/airflow/api/common/trigger_dag.py @@ -85,10 +85,10 @@ def _trigger_dag( run_id = run_id or dag.timetable.generate_run_id( run_type=DagRunType.MANUAL, logical_date=coerced_logical_date, data_interval=data_interval ) - dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id, logical_date=logical_date) + dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id) if dag_run: - raise DagRunAlreadyExists(dag_run, logical_date=logical_date, run_id=run_id) + raise DagRunAlreadyExists(dag_run) run_conf = None if conf: diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 7f43c160f3273..00eb51bae10b2 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -522,19 +522,10 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION if not task: error_message = f"Task ID {task_id} not found" raise NotFound(error_message) - - logical_date = data.get("logical_date") run_id = data.get("dag_run_id") - if ( - logical_date - and ( - session.scalars( - select(TI).where(TI.task_id == task_id, TI.dag_id == dag_id, TI.logical_date == logical_date) - ).one_or_none() - ) - is None - ): - raise NotFound(detail=f"Task instance not found for task {task_id!r} on logical_date {logical_date}") + if not run_id: + error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}" + raise NotFound(detail=error_message) select_stmt = select(TI).where( TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == run_id, TI.map_index == -1 @@ -547,7 +538,6 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION tis = dag.set_task_instance_state( task_id=task_id, run_id=run_id, - logical_date=logical_date, state=data["new_state"], upstream=data["include_upstream"], downstream=data["include_downstream"], diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py index 3e864f18652c4..360ecdf277e76 100644 --- a/airflow/api_connexion/schemas/task_instance_schema.py +++ b/airflow/api_connexion/schemas/task_instance_schema.py @@ -29,7 +29,6 @@ from airflow.api_connexion.schemas.trigger_schema import TriggerSchema from airflow.models import TaskInstance from airflow.models.taskinstancehistory import TaskInstanceHistory -from airflow.utils.helpers import exactly_one from 
airflow.utils.state import TaskInstanceState
@@ -196,8 +195,7 @@ class SetTaskInstanceStateFormSchema(Schema):
     dry_run = fields.Boolean(load_default=True)
     task_id = fields.Str(required=True)
-    logical_date = fields.DateTime(validate=validate_istimezone)
-    dag_run_id = fields.Str()
+    dag_run_id = fields.Str(required=True)
     include_upstream = fields.Boolean(required=True)
     include_downstream = fields.Boolean(required=True)
     include_future = fields.Boolean(required=True)
@@ -209,12 +207,6 @@ class SetTaskInstanceStateFormSchema(Schema):
         ),
     )
 
-    @validates_schema
-    def validate_form(self, data, **kwargs):
-        """Validate set task instance state form."""
-        if not exactly_one(data.get("logical_date"), data.get("dag_run_id")):
-            raise ValidationError("Exactly one of logical_date or dag_run_id must be provided")
-
 
 class SetSingleTaskInstanceStateFormSchema(Schema):
     """Schema for handling the request of updating state of a single task instance."""
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 396186bf14d2b..2b5a6c18a8089 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -91,11 +91,46 @@ def _generate_temporary_run_id() -> str:
     return f"__airflow_temporary_run_{timezone.utcnow().isoformat()}__"
 
 
+def _fetch_dag_run_from_run_id_or_logical_date_string(
+    *,
+    dag_id: str,
+    value: str,
+    session: Session,
+) -> tuple[DagRun | DagRunPydantic | None, pendulum.DateTime | None]:
+    """
+    Try to find a DAG run with a given string value.
+
+    The string value may be a run ID, or a logical date in string form. We first
+    try to use it as a run_id; if a run is found, it is returned as-is.
+
+    Otherwise, the string value is parsed into a datetime. If that works, it is
+    used to find a DAG run.
+
+    The return value is a two-tuple. The first item is the found DAG run (or
+    *None* if one cannot be found). The second is the parsed logical date. This
+    second value can be used to create a new run by the calling function when
+    one cannot be found here.
+    """
+    if dag_run := DAG.fetch_dagrun(dag_id=dag_id, run_id=value, session=session):
+        return dag_run, dag_run.logical_date  # type: ignore[return-value]
+    try:
+        logical_date = timezone.parse(value)
+    except (ParserError, TypeError):
+        return dag_run, None
+    dag_run = session.scalar(
+        select(DagRun)
+        .where(DagRun.dag_id == dag_id, DagRun.logical_date == logical_date)
+        .order_by(DagRun.id.desc())
+        .limit(1)
+    )
+    return dag_run, logical_date
+
+
 def _get_dag_run(
     *,
     dag: DAG,
     create_if_necessary: CreateIfNecessary,
-    exec_date_or_run_id: str | None = None,
+    logical_date_or_run_id: str | None = None,
     session: Session | None = None,
 ) -> tuple[DagRun | DagRunPydantic, bool]:
     """
@@ -103,7 +138,7 @@ def _get_dag_run(
 
     This checks DAG runs like this:
 
-    1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
+    1. If the input ``logical_date_or_run_id`` matches a DAG run ID, return the run.
     2. Try to parse the input as a date. If that works, and the resulting
        date matches a DAG run's logical date, return the run.
     3. If ``create_if_necessary`` is *False* and the input works for neither of
@@ -112,23 +147,22 @@ def _get_dag_run(
        the logical date; otherwise use it as a run ID and set the logical date
       to the current time.
""" - if not exec_date_or_run_id and not create_if_necessary: - raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") - logical_date: pendulum.DateTime | None = None - if exec_date_or_run_id: - dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=exec_date_or_run_id, session=session) - if dag_run: - return dag_run, False - with suppress(ParserError, TypeError): - logical_date = timezone.parse(exec_date_or_run_id) - if logical_date: - dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, logical_date=logical_date, session=session) - if dag_run: + if not logical_date_or_run_id and not create_if_necessary: + raise ValueError("Must provide `logical_date_or_run_id` if not `create_if_necessary`.") + + logical_date = None + if logical_date_or_run_id: + dag_run, logical_date = _fetch_dag_run_from_run_id_or_logical_date_string( + dag_id=dag.dag_id, + value=logical_date_or_run_id, + session=session, + ) + if dag_run is not None: return dag_run, False elif not create_if_necessary: raise DagRunNotFound( f"DagRun for {dag.dag_id} with run_id or logical_date " - f"of {exec_date_or_run_id!r} not found" + f"of {logical_date_or_run_id!r} not found" ) if logical_date is not None: @@ -139,7 +173,7 @@ def _get_dag_run( if create_if_necessary == "memory": dag_run = DagRun( dag_id=dag.dag_id, - run_id=exec_date_or_run_id, + run_id=logical_date_or_run_id, logical_date=dag_run_logical_date, data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_logical_date), triggered_by=DagRunTriggeredByType.CLI, @@ -178,7 +212,7 @@ def _get_ti_db_access( raise ValueError(f"Provided task {task.task_id} is not in dag '{dag.dag_id}.") if not logical_date_or_run_id and not create_if_necessary: - raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") + raise ValueError("Must provide `logical_date_or_run_id` if not `create_if_necessary`.") if task.get_needs_expansion(): if map_index < 0: raise RuntimeError("No map_index passed to mapped task") @@ -186,7 +220,7 @@ def _get_ti_db_access( raise RuntimeError("map_index passed to non-mapped task") dag_run, dr_created = _get_dag_run( dag=dag, - exec_date_or_run_id=logical_date_or_run_id, + logical_date_or_run_id=logical_date_or_run_id, create_if_necessary=create_if_necessary, session=session, ) diff --git a/airflow/exceptions.py b/airflow/exceptions.py index 3b07b9a6fda96..fee0b5a671d54 100644 --- a/airflow/exceptions.py +++ b/airflow/exceptions.py @@ -230,13 +230,9 @@ class DagRunNotFound(AirflowNotFoundException): class DagRunAlreadyExists(AirflowBadRequest): """Raise when creating a DAG run for DAG which already has DAG run entry.""" - def __init__(self, dag_run: DagRun, logical_date: datetime.datetime, run_id: str) -> None: - super().__init__( - f"A DAG Run already exists for DAG {dag_run.dag_id} at {logical_date} with run id {run_id}" - ) + def __init__(self, dag_run: DagRun) -> None: + super().__init__(f"A DAG Run already exists for DAG {dag_run.dag_id} with run id {dag_run.run_id}") self.dag_run = dag_run - self.logical_date = logical_date - self.run_id = run_id def serialize(self): cls = self.__class__ @@ -249,13 +245,12 @@ def serialize(self): run_id=self.dag_run.run_id, external_trigger=self.dag_run.external_trigger, run_type=self.dag_run.run_type, - logical_date=self.dag_run.logical_date, ) dag_run.id = self.dag_run.id return ( f"{cls.__module__}.{cls.__name__}", (), - {"dag_run": dag_run, "logical_date": self.logical_date, "run_id": self.run_id}, + {"dag_run": dag_run}, ) diff --git a/airflow/models/dag.py 
b/airflow/models/dag.py index d898e1a52f4b8..8177025c2d8ed 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -27,6 +27,7 @@ from collections import defaultdict from contextlib import ExitStack from datetime import datetime, timedelta +from functools import cache from typing import ( TYPE_CHECKING, Any, @@ -108,7 +109,6 @@ ) from airflow.utils import timezone from airflow.utils.dag_cycle_tester import check_cycle -from airflow.utils.helpers import exactly_one from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, tuple_in_condition, with_row_locks @@ -879,38 +879,20 @@ def get_active_runs(self): @staticmethod @internal_api_call @provide_session - def fetch_dagrun( - dag_id: str, - logical_date: datetime | None = None, - run_id: str | None = None, - session: Session = NEW_SESSION, - ) -> DagRun | DagRunPydantic: + def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic: """ - Return the dag run for a given logical date or run_id if it exists, otherwise none. + Return the dag run for a given run_id if it exists, otherwise none. :param dag_id: The dag_id of the DAG to find. - :param logical_date: The logical date of the DagRun to find. :param run_id: The run_id of the DagRun to find. :param session: :return: The DagRun if found, otherwise None. """ - if not (logical_date or run_id): - raise TypeError("You must provide either the logical_date or the run_id") - query = select(DagRun) - if logical_date: - query = query.where(DagRun.dag_id == dag_id, DagRun.logical_date == logical_date) - if run_id: - query = query.where(DagRun.dag_id == dag_id, DagRun.run_id == run_id) - return session.scalar(query) + return session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id)) @provide_session - def get_dagrun( - self, - logical_date: datetime | None = None, - run_id: str | None = None, - session: Session = NEW_SESSION, - ) -> DagRun | DagRunPydantic: - return DAG.fetch_dagrun(dag_id=self.dag_id, logical_date=logical_date, run_id=run_id, session=session) + def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic: + return DAG.fetch_dagrun(dag_id=self.dag_id, run_id=run_id, session=session) @provide_session def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION): @@ -992,6 +974,7 @@ def get_task_instances( state=state or (), include_dependent_dags=False, exclude_task_ids=(), + exclude_run_ids=None, session=session, ) return session.scalars(cast(Select, query).order_by(DagRun.logical_date)).all() @@ -1007,6 +990,7 @@ def _get_task_instances( state: TaskInstanceState | Sequence[TaskInstanceState], include_dependent_dags: bool, exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, session: Session, dag_bag: DagBag | None = ..., ) -> Iterable[TaskInstance]: ... 
# pragma: no cover @@ -1023,6 +1007,7 @@ def _get_task_instances( state: TaskInstanceState | Sequence[TaskInstanceState], include_dependent_dags: bool, exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, session: Session, dag_bag: DagBag | None = ..., recursion_depth: int = ..., @@ -1041,6 +1026,7 @@ def _get_task_instances( state: TaskInstanceState | Sequence[TaskInstanceState], include_dependent_dags: bool, exclude_task_ids: Collection[str | tuple[str, int]] | None, + exclude_run_ids: frozenset[str] | None, session: Session, dag_bag: DagBag | None = None, recursion_depth: int = 0, @@ -1098,6 +1084,9 @@ def _get_task_instances( else: tis = tis.where(TaskInstance.state.in_(state)) + if exclude_run_ids: + tis = tis.where(not_(TaskInstance.run_id.in_(exclude_run_ids))) + if include_dependent_dags: # Recursively find external tasks indicated by ExternalTaskMarker from airflow.sensors.external_task import ExternalTaskMarker @@ -1170,6 +1159,7 @@ def _get_task_instances( include_dependent_dags=include_dependent_dags, as_pk_tuple=True, exclude_task_ids=exclude_task_ids, + exclude_run_ids=exclude_run_ids, dag_bag=dag_bag, session=session, recursion_depth=recursion_depth + 1, @@ -1216,7 +1206,6 @@ def set_task_instance_state( *, task_id: str, map_indexes: Collection[int] | None = None, - logical_date: datetime | None = None, run_id: str | None = None, state: TaskInstanceState, upstream: bool = False, @@ -1232,7 +1221,6 @@ def set_task_instance_state( :param task_id: Task ID of the TaskInstance :param map_indexes: Only set TaskInstance if its map_index matches. If None (default), all mapped TaskInstances of the task are set. - :param logical_date: Logical date of the TaskInstance :param run_id: The run_id of the TaskInstance :param state: State to set the TaskInstance to :param upstream: Include all upstream tasks of the given task_id @@ -1243,9 +1231,6 @@ def set_task_instance_state( """ from airflow.api.common.mark_tasks import set_state - if not exactly_one(logical_date, run_id): - raise ValueError("Exactly one of logical_date or run_id must be provided") - task = self.get_task(task_id) task.dag = self @@ -1257,7 +1242,6 @@ def set_task_instance_state( altered = set_state( tasks=tasks_to_set_state, - logical_date=logical_date, run_id=run_id, upstream=upstream, downstream=downstream, @@ -1280,26 +1264,37 @@ def set_task_instance_state( include_upstream=False, ) - if logical_date is None: - dag_run = session.scalars( - select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id) - ).one() # Raises an error if not found - resolve_logical_date = dag_run.logical_date - else: - resolve_logical_date = logical_date - - end_date = resolve_logical_date if not future else None - start_date = resolve_logical_date if not past else None - - subdag.clear( - start_date=start_date, - end_date=end_date, - only_failed=True, - session=session, - # Exclude the task itself from being cleared - exclude_task_ids=frozenset({task_id}), - ) - + # Raises an error if not found + dr_id, logical_date = session.execute( + select(DagRun.id, DagRun.logical_date).where( + DagRun.run_id == run_id, DagRun.dag_id == self.dag_id + ) + ).one() + + # Now we want to clear downstreams of tasks that had their state set... + clear_kwargs = { + "only_failed": True, + "session": session, + # Exclude the task itself from being cleared. + "exclude_task_ids": frozenset((task_id,)), + } + if not future and not past: # Simple case 1: we're only dealing with exactly one run. 
+            clear_kwargs["run_id"] = run_id
+            subdag.clear(**clear_kwargs)
+        elif future and past:  # Simple case 2: we're clearing ALL runs.
+            subdag.clear(**clear_kwargs)
+        else:  # Complex cases: we may have more than one run, based on a date range.
+            # Give 'future' and 'past' a sensible meaning when multiple runs
+            # exist for the same logical date. We order runs by their id and
+            # only clear runs with larger/smaller ids.
+            exclude_run_id_stmt = select(DagRun.run_id).where(DagRun.logical_date == logical_date)
+            if future:
+                clear_kwargs["start_date"] = logical_date
+                exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id > dr_id)
+            else:
+                clear_kwargs["end_date"] = logical_date
+                exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id < dr_id)
+            subdag.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)), **clear_kwargs)
         return altered
 
     @provide_session
@@ -1307,7 +1302,6 @@ def set_task_group_state(
         self,
         *,
         group_id: str,
-        logical_date: datetime | None = None,
         run_id: str | None = None,
         state: TaskInstanceState,
         upstream: bool = False,
@@ -1321,7 +1315,6 @@ def set_task_group_state(
         Set TaskGroup to the given state and clear downstream tasks in failed or upstream_failed state.
 
         :param group_id: The group_id of the TaskGroup
-        :param logical_date: Logical date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
         :param upstream: Include all upstream tasks of the given task_id
@@ -1333,23 +1326,9 @@ def set_task_group_state(
         """
         from airflow.api.common.mark_tasks import set_state
 
-        if not exactly_one(logical_date, run_id):
-            raise ValueError("Exactly one of logical_date or run_id must be provided")
-
         tasks_to_set_state: list[BaseOperator | tuple[BaseOperator, int]] = []
         task_ids: list[str] = []
 
-        if logical_date is None:
-            dag_run = session.scalars(
-                select(DagRun).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
-            ).one()  # Raises an error if not found
-            resolve_logical_date = dag_run.logical_date
-        else:
-            resolve_logical_date = logical_date
-
-        end_date = resolve_logical_date if not future else None
-        start_date = resolve_logical_date if not past else None
-
         task_group_dict = self.task_group.get_task_group_dict()
         task_group = task_group_dict.get(group_id)
         if task_group is None:
@@ -1357,18 +1336,25 @@ def set_task_group_state(
         tasks_to_set_state = [task for task in task_group.iter_tasks() if isinstance(task, BaseOperator)]
         task_ids = [task.task_id for task in task_group.iter_tasks()]
         dag_runs_query = select(DagRun.id).where(DagRun.dag_id == self.dag_id)
-        if start_date is None and end_date is None:
-            dag_runs_query = dag_runs_query.where(DagRun.logical_date == start_date)
-        else:
-            if start_date is not None:
-                dag_runs_query = dag_runs_query.where(DagRun.logical_date >= start_date)
-            if end_date is not None:
-                dag_runs_query = dag_runs_query.where(DagRun.logical_date <= end_date)
+
+        @cache
+        def get_logical_date() -> datetime:
+            stmt = select(DagRun.logical_date).where(DagRun.run_id == run_id, DagRun.dag_id == self.dag_id)
+            return session.scalars(stmt).one()  # Raises an error if not found
+
+        end_date = None if future else get_logical_date()
+        start_date = None if past else get_logical_date()
+
+        if start_date is not None:
+            dag_runs_query = dag_runs_query.where(DagRun.logical_date >= start_date)
+        if end_date is not None:
+            dag_runs_query = dag_runs_query.where(DagRun.logical_date <= end_date)
+        if not future and not past:
+            dag_runs_query = dag_runs_query.where(DagRun.run_id == run_id)
 
         with lock_rows(dag_runs_query, session):
            altered = set_state(
                 tasks=tasks_to_set_state,
-                logical_date=logical_date,
                 run_id=run_id,
                 upstream=upstream,
                 downstream=downstream,
@@ -1416,6 +1402,7 @@ def clear(
         session: Session = NEW_SESSION,
         dag_bag: DagBag | None = None,
         exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
+        exclude_run_ids: frozenset[str] | None = frozenset(),
     ) -> list[TaskInstance]: ...  # pragma: no cover
 
     @overload
@@ -1433,12 +1420,15 @@ def clear(
         session: Session = NEW_SESSION,
         dag_bag: DagBag | None = None,
         exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
+        exclude_run_ids: frozenset[str] | None = frozenset(),
     ) -> int: ...  # pragma: no cover
 
     @provide_session
     def clear(
         self,
         task_ids: Collection[str | tuple[str, int]] | None = None,
+        *,
+        run_id: str | None = None,
         start_date: datetime | None = None,
         end_date: datetime | None = None,
         only_failed: bool = False,
@@ -1449,7 +1439,8 @@ def clear(
         session: Session = NEW_SESSION,
         dag_bag: DagBag | None = None,
         exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
-    ) -> int | list[TaskInstance]:
+        exclude_run_ids: frozenset[str] | None = frozenset(),
+    ) -> int | Iterable[TaskInstance]:
         """
         Clear a set of task instances associated with the current dag for a specified date range.
 
@@ -1466,6 +1457,7 @@ def clear(
         :param dag_bag: The DagBag used to find the dags (Optional)
         :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
             tuples that should not be cleared
+        :param exclude_run_ids: A set of ``run_id`` values for DAG runs that should not be cleared
         """
         state: list[TaskInstanceState] = []
         if only_failed:
@@ -1478,12 +1470,13 @@ def clear(
             task_ids=task_ids,
             start_date=start_date,
             end_date=end_date,
-            run_id=None,
+            run_id=run_id,
             state=state,
             include_dependent_dags=True,
             session=session,
             dag_bag=dag_bag,
             exclude_task_ids=exclude_task_ids,
+            exclude_run_ids=exclude_run_ids,
         )
 
         if dry_run:
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a2327221ad5df..b535d8729ea7b 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -591,31 +591,21 @@ def find(
 
     @classmethod
     @provide_session
-    def find_duplicate(
-        cls,
-        dag_id: str,
-        run_id: str,
-        logical_date: datetime,
-        session: Session = NEW_SESSION,
-    ) -> DagRun | None:
+    def find_duplicate(cls, dag_id: str, run_id: str, *, session: Session = NEW_SESSION) -> DagRun | None:
         """
-        Return an existing run for the DAG with a specific run_id or logical date.
+        Return an existing run for the DAG with a specific run_id.
+
+        *None* is returned if no such DAG run is found.
        :param dag_id: the dag_id to find duplicates for
         :param run_id: defines the run id for this dag run
-        :param logical_date: the logical date
         :param session: database session
         """
-        return session.scalars(
-            select(cls).where(
-                cls.dag_id == dag_id,
-                or_(cls.run_id == run_id, cls.logical_date == logical_date),
-            )
-        ).one_or_none()
+        return session.scalars(select(cls).where(cls.dag_id == dag_id, cls.run_id == run_id)).one_or_none()
 
     @staticmethod
     def generate_run_id(run_type: DagRunType, logical_date: datetime) -> str:
-        """Generate Run ID based on Run Type and Logical Date."""
+        """Generate Run ID based on Run Type and logical date."""
         # _Ensure_ run_type is a DagRunType, not just a string from user code
         return DagRunType(run_type).generate_run_id(logical_date)
diff --git a/airflow/www/views.py b/airflow/www/views.py
index 805a746fba5ce..e58055b9f1cc3 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -1417,7 +1417,10 @@ def rendered_templates(self, session):
         logger.info("Retrieving rendered templates.")
         dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
 
-        dag_run = dag.get_dagrun(logical_date=dttm)
+        run_id = session.scalar(
+            select(DagRun.run_id).where(DagRun.logical_date == dttm).order_by(DagRun.id.desc()).limit(1)
+        )
+        dag_run = dag.get_dagrun(run_id=run_id, session=session)
         raw_task = dag.get_task(task_id).prepare_for_execution()
 
         no_dagrun = False
@@ -1550,10 +1553,11 @@ def rendered_k8s(self, *, session: Session = NEW_SESSION):
         dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
         task = dag.get_task(task_id)
 
+        run_id = session.scalar(
+            select(DagRun.run_id).where(DagRun.logical_date == dttm).order_by(DagRun.id.desc()).limit(1)
+        )
         dag_run = dag.get_dagrun(
-            run_id=session.scalar(
-                select(DagRun.run_id).where(DagRun.logical_date == dttm).order_by(DagRun.id.desc()).limit(1)
-            ),
+            run_id=run_id,
             session=session,
         )
         ti = dag_run.get_task_instance(task_id=task.task_id, map_index=map_index, session=session)
@@ -2144,7 +2148,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION):
                 form=form,
             )
 
-        dr = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id, logical_date=logical_date)
+        dr = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id, session=session)
         if dr:
             if dr.run_id == run_id:
                 message = f"The run ID {run_id} already exists"
@@ -2408,7 +2412,7 @@ def dagrun_clear(self, *, session: Session = NEW_SESSION):
         only_failed = request.form.get("only_failed") == "true"
 
         dag = get_airflow_app().dag_bag.get_dag(dag_id)
-        dr = dag.get_dagrun(run_id=dag_run_id)
+        dr = dag.get_dagrun(run_id=dag_run_id, session=session)
         start_date = dr.logical_date
         end_date = dr.logical_date
 
@@ -3060,8 +3064,16 @@ def graph(self, dag_id: str, session: Session = NEW_SESSION):
             flash(f'DAG "{dag_id}" seems to be missing from DagBag.', "error")
             return redirect(url_for("Airflow.index"))
         dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, session, dag)
-        dttm = dt_nr_dr_data["dttm"]
-        dag_run = dag.get_dagrun(logical_date=dttm)
+        run_id = session.scalar(
+            select(DagRun.run_id)
+            .where(DagRun.logical_date == dt_nr_dr_data["dttm"])
+            .order_by(DagRun.id.desc())
+            .limit(1)
+        )
+        dag_run = dag.get_dagrun(
+            run_id=run_id,
+            session=session,
+        )
         dag_run_id = dag_run.run_id if dag_run else None
 
         kwargs = {
@@ -3136,7 +3148,13 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION):
         dag = get_airflow_app().dag_bag.get_dag(dag_id, session=session)
         dt_nr_dr_data = get_date_time_num_runs_dag_runs_form_data(request, session, dag)
         dttm = dt_nr_dr_data["dttm"]
-        dag_run = dag.get_dagrun(logical_date=dttm)
+        run_id = 
session.scalar( + select(DagRun.run_id).where(DagRun.logical_date == dttm).order_by(DagRun.id.desc()).limit(1) + ) + dag_run = dag.get_dagrun( + run_id=run_id, + session=session, + ) dag_run_id = dag_run.run_id if dag_run else None kwargs = {**sanitize_args(request.args), "dag_id": dag_id, "tab": "gantt", "dag_run_id": dag_run_id} diff --git a/newsfragments/42404.significant.rst b/newsfragments/42404.significant.rst new file mode 100644 index 0000000000000..47546b76ffaed --- /dev/null +++ b/newsfragments/42404.significant.rst @@ -0,0 +1,6 @@ +Removed ``logical_date`` arguments from functions and APIs for DAG run lookups to align with Airflow 3.0. + +The shift towards using ``run_id`` as the sole identifier for DAG runs eliminates the limitations of ``execution_date`` and ``logical_date``, particularly for dynamic DAG runs and cases where multiple runs occur at the same logical time. This change impacts database models, templates, and functions: + +- Removed ``logical_date`` arguments from public APIs and Python functions related to DAG run lookups. +- ``run_id`` is now the exclusive identifier for DAG runs in these contexts. diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index a14a0be33a594..92ad62b67887d 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -1721,6 +1721,7 @@ class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): @mock.patch("airflow.models.dag.DAG.set_task_instance_state") def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, session): self.create_task_instances(session) + run_id = "TEST_DAG_RUN_ID" mock_set_task_instance_state.return_value = ( session.query(TaskInstance) .join(TaskInstance.dag_run) @@ -1734,7 +1735,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi json={ "dry_run": True, "task_id": "print_the_context", - "logical_date": DEFAULT_DATETIME_1.isoformat(), + "dag_run_id": run_id, "include_upstream": True, "include_downstream": True, "include_future": True, @@ -1757,8 +1758,7 @@ def test_should_assert_call_mocked_api(self, mock_set_task_instance_state, sessi mock_set_task_instance_state.assert_called_once_with( commit=False, downstream=True, - run_id=None, - logical_date=DEFAULT_DATETIME_1, + run_id=run_id, future=True, past=True, state="failed", @@ -1807,7 +1807,6 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ commit=False, downstream=True, run_id=run_id, - logical_date=None, future=True, past=True, state="failed", @@ -1820,7 +1819,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ "error, code, payload", [ [ - "{'_schema': ['Exactly one of logical_date or dag_run_id must be provided']}", + "{'dag_run_id': ['Missing data for required field.']}", 400, { "dry_run": True, @@ -1833,9 +1832,8 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ }, ], [ - "Task instance not found for task 'print_the_context' on logical_date " - "2021-01-01 00:00:00+00:00", - 404, + "{'dag_run_id': ['Missing data for required field.'], 'logical_date': ['Unknown field.']}", + 400, { "dry_run": True, "task_id": "print_the_context", @@ -1862,7 +1860,7 @@ def test_should_assert_call_mocked_api_when_run_id(self, mock_set_task_instance_ }, ], [ - "{'_schema': ['Exactly one of logical_date or dag_run_id must be provided']}", + "{'logical_date': 
['Unknown field.']}", 400, { "dry_run": True, @@ -1928,7 +1926,7 @@ def test_should_raise_404_not_found_dag(self): json={ "dry_run": True, "task_id": "print_the_context", - "logical_date": DEFAULT_DATETIME_1.isoformat(), + "dag_run_id": "random_run_id", "include_upstream": True, "include_downstream": True, "include_future": True, @@ -1941,14 +1939,14 @@ def test_should_raise_404_not_found_dag(self): @mock.patch("airflow.models.dag.DAG.set_task_instance_state") def test_should_raise_not_found_if_run_id_is_wrong(self, mock_set_task_instance_state, session): self.create_task_instances(session) - date = DEFAULT_DATETIME_1 + dt.timedelta(days=1) + run_id = "random_run_id" response = self.client.post( "/api/v1/dags/example_python_operator/updateTaskInstancesState", environ_overrides={"REMOTE_USER": "test"}, json={ "dry_run": True, "task_id": "print_the_context", - "logical_date": date.isoformat(), + "dag_run_id": run_id, "include_upstream": True, "include_downstream": True, "include_future": True, @@ -1958,7 +1956,7 @@ def test_should_raise_not_found_if_run_id_is_wrong(self, mock_set_task_instance_ ) assert response.status_code == 404 assert response.json["detail"] == ( - f"Task instance not found for task 'print_the_context' on logical_date {date}" + f"Task instance not found for task 'print_the_context' on DAG run with ID '{run_id}'" ) assert mock_set_task_instance_state.call_count == 0 @@ -1969,7 +1967,7 @@ def test_should_raise_404_not_found_task(self): json={ "dry_run": True, "task_id": "INVALID_TASK", - "logical_date": DEFAULT_DATETIME_1.isoformat(), + "dag_run_id": "TEST_DAG_RUN_ID", "include_upstream": True, "include_downstream": True, "include_future": True, @@ -1979,48 +1977,6 @@ def test_should_raise_404_not_found_task(self): ) assert response.status_code == 404 - @pytest.mark.parametrize( - "payload, expected", - [ - ( - { - "dry_run": True, - "task_id": "print_the_context", - "logical_date": "2020-11-10T12:42:39.442973", - "include_upstream": True, - "include_downstream": True, - "include_future": True, - "include_past": True, - "new_state": "failed", - }, - "Naive datetime is disallowed", - ), - ( - { - "dry_run": True, - "task_id": "print_the_context", - "logical_date": "2020-11-10T12:4opfo", - "include_upstream": True, - "include_downstream": True, - "include_future": True, - "include_past": True, - "new_state": "failed", - }, - "{'logical_date': ['Not a valid datetime.']}", - ), - ], - ) - @provide_session - def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, session): - self.create_task_instances(session) - response = self.client.post( - "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": "test"}, - json=payload, - ) - assert response.status_code == 400 - assert response.json["detail"] == expected - class TestPatchTaskInstance(TestTaskInstanceEndpoint): ENDPOINT_URL = ( diff --git a/tests/api_connexion/schemas/test_task_instance_schema.py b/tests/api_connexion/schemas/test_task_instance_schema.py index 5297830dca011..9080572314fdc 100644 --- a/tests/api_connexion/schemas/test_task_instance_schema.py +++ b/tests/api_connexion/schemas/test_task_instance_schema.py @@ -166,7 +166,7 @@ class TestSetTaskInstanceStateFormSchema: current_input = { "dry_run": True, "task_id": "print_the_context", - "logical_date": "2020-01-01T00:00:00+00:00", + "dag_run_id": "test_run_id", "include_upstream": True, "include_downstream": True, "include_future": True, @@ -178,7 +178,7 @@ def test_success(self): result = 
set_task_instance_state_form.load(self.current_input) expected_result = { "dry_run": True, - "logical_date": dt.datetime(2020, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(0), "+0000")), + "dag_run_id": "test_run_id", "include_downstream": True, "include_future": True, "include_past": True, @@ -194,7 +194,7 @@ def test_dry_run_is_optional(self): result = set_task_instance_state_form.load(self.current_input) expected_result = { "dry_run": True, - "logical_date": dt.datetime(2020, 1, 1, 0, 0, tzinfo=dt.timezone(dt.timedelta(0), "+0000")), + "dag_run_id": "test_run_id", "include_downstream": True, "include_future": True, "include_past": True, diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index ca8cadd9368bf..f78b53c5a769e 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -34,7 +34,6 @@ import pendulum import pytest -import sqlalchemy.exc from airflow.cli import cli_parser from airflow.cli.commands import task_command @@ -228,18 +227,29 @@ def test_cli_test_different_path(self, session, tmp_path): assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix() @mock.patch("airflow.cli.commands.task_command.select") - @mock.patch("sqlalchemy.orm.session.Session.scalars") - @mock.patch("airflow.cli.commands.task_command.DagRun") - def test_task_render_with_custom_timetable(self, mock_dagrun, mock_scalars, mock_select): + @mock.patch("sqlalchemy.orm.session.Session.scalar") + def test_task_render_with_custom_timetable(self, mock_scalar, mock_select): """ - when calling `tasks render` on dag with custom timetable, the DagRun object should be created with - data_intervals. + Test that the `tasks render` CLI command queries the database correctly + for a DAG with a custom timetable. Verifies that a query is executed to + fetch the appropriate DagRun and that the database interaction occurs as expected. 
""" - mock_scalars.side_effect = sqlalchemy.exc.NoResultFound + from sqlalchemy import select + + from airflow.models.dagrun import DagRun + + mock_query = ( + select(DagRun).where(DagRun.dag_id == "example_workday_timetable").order_by(DagRun.id.desc()) + ) + mock_select.return_value = mock_query + + mock_scalar.return_value = None + task_command.task_render( self.parser.parse_args(["tasks", "render", "example_workday_timetable", "run_this", "2022-01-01"]) ) - assert "data_interval" in mock_dagrun.call_args.kwargs + + mock_select.assert_called_once() @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") def test_test_with_existing_dag_run(self, caplog): diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index f2128b205b4d2..1f946d027dfe1 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -2613,13 +2613,10 @@ def return_num(num): @pytest.mark.parametrize( - "run_id, logical_date", - [ - (None, datetime_tz(2020, 1, 1)), - ("test-run-id", None), - ], + "run_id", + ["test-run-id"], ) -def test_set_task_instance_state(run_id, logical_date, session, dag_maker): +def test_set_task_instance_state(run_id, session, dag_maker): """Test that set_task_instance_state updates the TaskInstance state and clear downstream failed""" start_date = datetime_tz(2020, 1, 1) @@ -2633,7 +2630,6 @@ def test_set_task_instance_state(run_id, logical_date, session, dag_maker): dagrun = dag_maker.create_dagrun( run_id=run_id, - logical_date=logical_date, state=State.FAILED, run_type=DagRunType.SCHEDULED, ) @@ -2654,12 +2650,12 @@ def get_ti_from_db(task): get_ti_from_db(task_3).state = State.UPSTREAM_FAILED get_ti_from_db(task_4).state = State.FAILED get_ti_from_db(task_5).state = State.SKIPPED + session.flush() altered = dag.set_task_instance_state( task_id=task_1.task_id, run_id=run_id, - logical_date=logical_date, state=State.SUCCESS, session=session, ) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 39d6d56ef3762..e26fe620c0b01 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -220,10 +220,9 @@ def test_dagrun_find_duplicate(self, session): session.commit() - assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, logical_date=now) is not None - assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id, logical_date=None) is not None - assert DagRun.find_duplicate(dag_id=dag_id, run_id=None, logical_date=now) is not None - assert DagRun.find_duplicate(dag_id=dag_id, run_id=None, logical_date=None) is None + assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id) is not None + assert DagRun.find_duplicate(dag_id=dag_id, run_id=dag_id) is not None + assert DagRun.find_duplicate(dag_id=dag_id, run_id=None) is None def test_dagrun_success_when_all_skipped(self, session): """ diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index ad1fca15ea46d..85daeaed275e0 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -215,12 +215,14 @@ def test_trigger_dagrun_twice(self, dag_maker): def test_trigger_dagrun_with_scheduled_dag_run(self, dag_maker): """Test TriggerDagRunOperator with custom logical_date and scheduled dag_run.""" utc_now = timezone.utcnow() + run_id = f"scheduled__{utc_now.isoformat()}" with dag_maker( TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True ) as dag: task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_logical_date", 
                trigger_dag_id=TRIGGERED_DAG_ID,
+                trigger_run_id=run_id,
                 logical_date=utc_now,
                 poke_interval=1,
                 reset_dag_run=True,
@@ -496,23 +498,6 @@ def test_trigger_dagrun_triggering_itself(self, dag_maker):
         triggered_dag_run = dagruns[1]
         assert triggered_dag_run.state == State.QUEUED
 
-    def test_trigger_dagrun_triggering_itself_with_logical_date(self, dag_maker):
-        """Test TriggerDagRunOperator that triggers itself with logical date,
-        fails with DagRunAlreadyExists"""
-        logical_date = DEFAULT_DATE
-        with dag_maker(
-            TEST_DAG_ID, default_args={"owner": "airflow", "start_date": DEFAULT_DATE}, serialized=True
-        ) as dag:
-            task = TriggerDagRunOperator(
-                task_id="test_task",
-                trigger_dag_id=TEST_DAG_ID,
-                logical_date=logical_date,
-            )
-        self.re_sync_triggered_dag_to_db(dag, dag_maker)
-        dag_maker.create_dagrun()
-        with pytest.raises(DagRunAlreadyExists):
-            task.run(start_date=logical_date, end_date=logical_date)
-
     @pytest.mark.skip_if_database_isolation_mode  # Known to be broken in db isolation mode
     def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self, dag_maker):
         """Test TriggerDagRunOperator with wait_for_completion."""
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index 4cd7b5e5f8c6f..e03ceeed01960 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -1295,9 +1295,10 @@ def run_tasks(
         start_date=logical_date,
         run_type=DagRunType.MANUAL,
         session=session,
-        data_interval=(DEFAULT_DATE, DEFAULT_DATE),
+        data_interval=(logical_date, logical_date),
         triggered_by=triggered_by,
     )
+    runs[dag.dag_id] = dagrun
     # we use sorting by task_id here because for the test DAG structure of ours
     # this is equivalent to topological sort. It would not work in general case
     # but it works for our case because we specifically constructed test DAGS
@@ -1373,9 +1374,22 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
     run_tasks(dag_bag, logical_date=day_1)
     run_tasks(dag_bag, logical_date=day_2)
 
+    from sqlalchemy import select
+
+    run_ids = []
     # Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
     for dag, logical_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
-        dagrun = dag.get_dagrun(logical_date=logical_date, session=session)
+        run_id = session.scalar(
+            select(DagRun.run_id)
+            .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == logical_date)
+            .order_by(DagRun.id.desc())
+            .limit(1)
+        )
+        run_ids.append(run_id)
+        dagrun = dag.get_dagrun(
+            run_id=run_id,
+            session=session,
+        )
         dagrun.set_state(State.SUCCESS)
     session.flush()
 
@@ -1385,10 +1399,10 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
 
     # Assert that dagruns of all the affected dags are set to QUEUED after tasks are cleared.
     # Unaffected dagruns should be left as SUCCESS.
- dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(logical_date=day_1, session=session) - dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(logical_date=day_2, session=session) - dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(logical_date=day_1, session=session) - dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(logical_date=day_2, session=session) + dagrun_0_1 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[0], session=session) + dagrun_0_2 = dag_bag.get_dag("parent_dag_0").get_dagrun(run_id=run_ids[1], session=session) + dagrun_1_1 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[2], session=session) + dagrun_1_2 = dag_bag.get_dag("child_dag_1").get_dagrun(run_id=run_ids[3], session=session) assert dagrun_0_1.state == State.QUEUED assert dagrun_0_2.state == State.QUEUED
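The migration pattern for callers is the one this patch applies repeatedly in ``airflow/www/views.py``: resolve a ``run_id`` first, then look the run up by that ID. A minimal sketch of the new-style lookup, assuming a loaded ``DagBag``, the stock ``example_python_operator`` example DAG, an ORM session from ``airflow.settings``, and an existing run at the given date::

    from sqlalchemy import select

    from airflow.models.dagbag import DagBag
    from airflow.models.dagrun import DagRun
    from airflow.settings import Session
    from airflow.utils import timezone

    session = Session()
    dag = DagBag().get_dag("example_python_operator")  # assumed: any DAG loaded in the DagBag
    dttm = timezone.parse("2024-01-01T00:00:00+00:00")  # the value previously passed as logical_date

    # When several runs share a logical date, prefer the most recently
    # created one (largest DagRun.id), as the patched views.py code does.
    run_id = session.scalar(
        select(DagRun.run_id)
        .where(DagRun.dag_id == dag.dag_id, DagRun.logical_date == dttm)
        .order_by(DagRun.id.desc())
        .limit(1)
    )

    # get_dagrun() now accepts run_id only; the logical_date argument was removed.
    dag_run = dag.get_dagrun(run_id=run_id, session=session)

Ordering by ``DagRun.id`` is what makes the lookup deterministic when multiple runs exist for the same logical time, which is exactly the case the ``logical_date`` API could not distinguish.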