Skip to content

Commit

Permalink
Remove logical_date from APIs and Functions, use run_id instead (apac…
Browse files Browse the repository at this point in the history
…he#42404)

Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
sunank200 and uranusjr authored Nov 20, 2024
1 parent 8440016 commit aa7a3b2
Show file tree
Hide file tree
Showing 17 changed files with 231 additions and 310 deletions.
57 changes: 0 additions & 57 deletions airflow/api/common/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
Expand Down Expand Up @@ -87,7 +86,6 @@ def set_state(
*,
tasks: Collection[Operator | tuple[Operator, int]],
run_id: str | None = None,
logical_date: datetime | None = None,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
Expand All @@ -107,7 +105,6 @@ def set_state(
:param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
``task.dag`` needs to be set
:param run_id: the run_id of the dagrun to start looking from
:param logical_date: the logical date from which to start looking (deprecated)
:param upstream: Mark all parents (upstream tasks)
:param downstream: Mark all children (downstream tasks) of task_id
:param future: Mark all future tasks on the interval of the dag up until
Expand All @@ -121,21 +118,12 @@ def set_state(
if not tasks:
return []

if not exactly_one(logical_date, run_id):
raise ValueError("Exactly one of dag_run_id and logical_date must be set")

if logical_date and not timezone.is_localized(logical_date):
raise ValueError(f"Received non-localized date {logical_date}")

task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
if len(task_dags) > 1:
raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
dag = next(iter(task_dags))
if dag is None:
raise ValueError("Received tasks with no DAG")

if logical_date:
run_id = dag.get_dagrun(logical_date=logical_date, session=session).run_id
if not run_id:
raise ValueError("Received tasks with no run_id")

Expand Down Expand Up @@ -279,7 +267,6 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
def set_dag_run_state_to_success(
*,
dag: DAG,
logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
Expand All @@ -290,27 +277,15 @@ def set_dag_run_state_to_success(
Set for a specific logical date and its task instances to success.
:param dag: the DAG of which to alter state
:param logical_date: the logical date from which to start looking(deprecated)
:param run_id: the run_id to start looking from
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
:return: If commit is true, list of tasks that have been updated,
otherwise list of tasks that will be updated
:raises: ValueError if run_id is invalid
"""
if not exactly_one(logical_date, run_id):
return []

if not dag:
return []

if logical_date:
if not timezone.is_localized(logical_date):
raise ValueError(f"Received non-localized date {logical_date}")
dag_run = dag.get_dagrun(logical_date=logical_date)
if not dag_run:
raise ValueError(f"DagRun with logical_date: {logical_date} not found")
run_id = dag_run.run_id
if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")
# Mark the dag run to success.
Expand All @@ -333,7 +308,6 @@ def set_dag_run_state_to_success(
def set_dag_run_state_to_failed(
*,
dag: DAG,
logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
Expand All @@ -344,27 +318,14 @@ def set_dag_run_state_to_failed(
Set for a specific logical date and its task instances to failed.
:param dag: the DAG of which to alter state
:param logical_date: the logical date from which to start looking(deprecated)
:param run_id: the DAG run_id to start looking from
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
:return: If commit is true, list of tasks that have been updated,
otherwise list of tasks that will be updated
:raises: ValueError if run_id is invalid
"""
if not exactly_one(logical_date, run_id):
return []
if not dag:
return []

if logical_date:
if not timezone.is_localized(logical_date):
raise ValueError(f"Received non-localized date {logical_date}")
dag_run = dag.get_dagrun(logical_date=logical_date)
if not dag_run:
raise ValueError(f"DagRun with logical_date: {logical_date} not found")
run_id = dag_run.run_id

if not run_id:
raise ValueError(f"Invalid dag_run_id: {run_id}")

Expand Down Expand Up @@ -429,7 +390,6 @@ def __set_dag_run_state_to_running_or_queued(
*,
new_state: DagRunState,
dag: DAG,
logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession,
Expand All @@ -438,28 +398,15 @@ def __set_dag_run_state_to_running_or_queued(
Set the dag run for a specific logical date to running.
:param dag: the DAG of which to alter state
:param logical_date: the logical date from which to start looking
:param run_id: the id of the DagRun
:param commit: commit DAG and tasks to be altered to the database
:param session: database session
:return: If commit is true, list of tasks that have been updated,
otherwise list of tasks that will be updated
"""
res: list[TaskInstance] = []

if not exactly_one(logical_date, run_id):
return res

if not dag:
return res

if logical_date:
if not timezone.is_localized(logical_date):
raise ValueError(f"Received non-localized date {logical_date}")
dag_run = dag.get_dagrun(logical_date=logical_date)
if not dag_run:
raise ValueError(f"DagRun with logical_date: {logical_date} not found")
run_id = dag_run.run_id
if not run_id:
raise ValueError(f"DagRun with run_id: {run_id} not found")
# Mark the dag run to running.
Expand All @@ -474,7 +421,6 @@ def __set_dag_run_state_to_running_or_queued(
def set_dag_run_state_to_running(
*,
dag: DAG,
logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
Expand All @@ -487,7 +433,6 @@ def set_dag_run_state_to_running(
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.RUNNING,
dag=dag,
logical_date=logical_date,
run_id=run_id,
commit=commit,
session=session,
Expand All @@ -498,7 +443,6 @@ def set_dag_run_state_to_running(
def set_dag_run_state_to_queued(
*,
dag: DAG,
logical_date: datetime | None = None,
run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
Expand All @@ -511,7 +455,6 @@ def set_dag_run_state_to_queued(
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.QUEUED,
dag=dag,
logical_date=logical_date,
run_id=run_id,
commit=commit,
session=session,
Expand Down
4 changes: 2 additions & 2 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def _trigger_dag(
run_id = run_id or dag.timetable.generate_run_id(
run_type=DagRunType.MANUAL, logical_date=coerced_logical_date, data_interval=data_interval
)
dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id, logical_date=logical_date)
dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id)

if dag_run:
raise DagRunAlreadyExists(dag_run, logical_date=logical_date, run_id=run_id)
raise DagRunAlreadyExists(dag_run)

run_conf = None
if conf:
Expand Down
16 changes: 3 additions & 13 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,19 +522,10 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
if not task:
error_message = f"Task ID {task_id} not found"
raise NotFound(error_message)

logical_date = data.get("logical_date")
run_id = data.get("dag_run_id")
if (
logical_date
and (
session.scalars(
select(TI).where(TI.task_id == task_id, TI.dag_id == dag_id, TI.logical_date == logical_date)
).one_or_none()
)
is None
):
raise NotFound(detail=f"Task instance not found for task {task_id!r} on logical_date {logical_date}")
if not run_id:
error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
raise NotFound(detail=error_message)

select_stmt = select(TI).where(
TI.dag_id == dag_id, TI.task_id == task_id, TI.run_id == run_id, TI.map_index == -1
Expand All @@ -547,7 +538,6 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
tis = dag.set_task_instance_state(
task_id=task_id,
run_id=run_id,
logical_date=logical_date,
state=data["new_state"],
upstream=data["include_upstream"],
downstream=data["include_downstream"],
Expand Down
10 changes: 1 addition & 9 deletions airflow/api_connexion/schemas/task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from airflow.api_connexion.schemas.trigger_schema import TriggerSchema
from airflow.models import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.utils.helpers import exactly_one
from airflow.utils.state import TaskInstanceState


Expand Down Expand Up @@ -196,8 +195,7 @@ class SetTaskInstanceStateFormSchema(Schema):

dry_run = fields.Boolean(load_default=True)
task_id = fields.Str(required=True)
logical_date = fields.DateTime(validate=validate_istimezone)
dag_run_id = fields.Str()
dag_run_id = fields.Str(required=True)
include_upstream = fields.Boolean(required=True)
include_downstream = fields.Boolean(required=True)
include_future = fields.Boolean(required=True)
Expand All @@ -209,12 +207,6 @@ class SetTaskInstanceStateFormSchema(Schema):
),
)

@validates_schema
def validate_form(self, data, **kwargs):
"""Validate set task instance state form."""
if not exactly_one(data.get("logical_date"), data.get("dag_run_id")):
raise ValidationError("Exactly one of logical_date or dag_run_id must be provided")


class SetSingleTaskInstanceStateFormSchema(Schema):
"""Schema for handling the request of updating state of a single task instance."""
Expand Down
70 changes: 52 additions & 18 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,54 @@ def _generate_temporary_run_id() -> str:
return f"__airflow_temporary_run_{timezone.utcnow().isoformat()}__"


def _fetch_dag_run_from_run_id_or_logical_date_string(
    *,
    dag_id: str,
    value: str,
    session: Session,
) -> tuple[DagRun | DagRunPydantic, pendulum.DateTime | None]:
    """
    Look up a DAG run from a string that is either a run ID or a logical date.

    Resolution order:

    1. ``value`` is treated as a run ID. If a run with that ID exists for
       ``dag_id``, it is returned together with its own logical date.
    2. Otherwise ``value`` is parsed as a datetime. If parsing fails, no
       further lookup is attempted and the logical date is reported as *None*.
    3. If parsing succeeds, the run of ``dag_id`` whose logical date equals the
       parsed value is looked up; when several match, the one with the highest
       ``id`` is chosen.

    :param dag_id: ID of the DAG whose runs are searched
    :param value: a run ID, or a logical date in string form
    :param session: database session used for the lookups
    :return: A two-tuple ``(dag_run, logical_date)``. ``dag_run`` is the run
        that was found, or *None* when nothing matched. ``logical_date`` is the
        parsed (or stored) logical date, or *None* when ``value`` could not be
        parsed; callers can use it to create a new run when none was found.
    """
    run_by_id = DAG.fetch_dagrun(dag_id=dag_id, run_id=value, session=session)
    if run_by_id:
        return run_by_id, run_by_id.logical_date  # type: ignore[return-value]

    # Not a known run ID -- fall back to interpreting the string as a date.
    try:
        parsed_date = timezone.parse(value)
    except (ParserError, TypeError):
        # Neither a run ID nor a date: hand back the (empty) ID lookup result.
        return run_by_id, None

    latest_matching_run = (
        select(DagRun)
        .where(DagRun.dag_id == dag_id, DagRun.logical_date == parsed_date)
        .order_by(DagRun.id.desc())
        .limit(1)
    )
    return session.scalar(latest_matching_run), parsed_date


def _get_dag_run(
*,
dag: DAG,
create_if_necessary: CreateIfNecessary,
exec_date_or_run_id: str | None = None,
logical_date_or_run_id: str | None = None,
session: Session | None = None,
) -> tuple[DagRun | DagRunPydantic, bool]:
"""
Try to retrieve a DAG run from a string representing either a run ID or logical date.
This checks DAG runs like this:
1. If the input ``exec_date_or_run_id`` matches a DAG run ID, return the run.
1. If the input ``logical_date_or_run_id`` matches a DAG run ID, return the run.
2. Try to parse the input as a date. If that works, and the resulting
date matches a DAG run's logical date, return the run.
3. If ``create_if_necessary`` is *False* and the input works for neither of
Expand All @@ -112,23 +147,22 @@ def _get_dag_run(
the logical date; otherwise use it as a run ID and set the logical date
to the current time.
"""
if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
logical_date: pendulum.DateTime | None = None
if exec_date_or_run_id:
dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=exec_date_or_run_id, session=session)
if dag_run:
return dag_run, False
with suppress(ParserError, TypeError):
logical_date = timezone.parse(exec_date_or_run_id)
if logical_date:
dag_run = DAG.fetch_dagrun(dag_id=dag.dag_id, logical_date=logical_date, session=session)
if dag_run:
if not logical_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `logical_date_or_run_id` if not `create_if_necessary`.")

logical_date = None
if logical_date_or_run_id:
dag_run, logical_date = _fetch_dag_run_from_run_id_or_logical_date_string(
dag_id=dag.dag_id,
value=logical_date_or_run_id,
session=session,
)
if dag_run is not None:
return dag_run, False
elif not create_if_necessary:
raise DagRunNotFound(
f"DagRun for {dag.dag_id} with run_id or logical_date "
f"of {exec_date_or_run_id!r} not found"
f"of {logical_date_or_run_id!r} not found"
)

if logical_date is not None:
Expand All @@ -139,7 +173,7 @@ def _get_dag_run(
if create_if_necessary == "memory":
dag_run = DagRun(
dag_id=dag.dag_id,
run_id=exec_date_or_run_id,
run_id=logical_date_or_run_id,
logical_date=dag_run_logical_date,
data_interval=dag.timetable.infer_manual_data_interval(run_after=dag_run_logical_date),
triggered_by=DagRunTriggeredByType.CLI,
Expand Down Expand Up @@ -178,15 +212,15 @@ def _get_ti_db_access(
raise ValueError(f"Provided task {task.task_id} is not in dag '{dag.dag_id}.")

if not logical_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
raise ValueError("Must provide `logical_date_or_run_id` if not `create_if_necessary`.")
if task.get_needs_expansion():
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
elif map_index >= 0:
raise RuntimeError("map_index passed to non-mapped task")
dag_run, dr_created = _get_dag_run(
dag=dag,
exec_date_or_run_id=logical_date_or_run_id,
logical_date_or_run_id=logical_date_or_run_id,
create_if_necessary=create_if_necessary,
session=session,
)
Expand Down
Loading

0 comments on commit aa7a3b2

Please sign in to comment.