From 6498f674f93b9f9880b474df0fdff60b02bc4450 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 6 Oct 2023 03:46:58 +0800 Subject: [PATCH] View util refactoring on mapped stuff use cases (#34638) --- airflow/www/views.py | 97 ++++++++++++++++++++------------------------ 1 file changed, 45 insertions(+), 52 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index c38c4f900b2cf..8495f15d07a90 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -18,12 +18,14 @@ from __future__ import annotations import collections.abc +import contextlib import copy import datetime import itertools import json import logging import math +import operator import sys import traceback import warnings @@ -82,6 +84,7 @@ set_state, ) from airflow.auth.managers.models.resource_details import DagAccessEntity +from airflow.compat.functools import cache from airflow.configuration import AIRFLOW_CONFIG, conf from airflow.datasets import Dataset from airflow.exceptions import ( @@ -100,7 +103,7 @@ from airflow.models.dag import get_dataset_triggered_next_run_info from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel -from airflow.models.mappedoperator import MappedOperator +from airflow.models.operator import needs_expansion from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import TaskInstance, TaskInstanceNote from airflow.providers_manager import ProvidersManager @@ -140,7 +143,6 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session - from airflow.models.abstractoperator import AbstractOperator from airflow.models.dag import DAG from airflow.models.operator import Operator @@ -294,7 +296,7 @@ def node_dict(node_id, label, node_class): } -def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): +def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session) -> dict[str, Any]: """ Create a nested dict representation of the DAG's TaskGroup and its children. @@ -321,49 +323,35 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): .order_by(TaskInstance.task_id, TaskInstance.run_id) ) - grouped_tis = {task_id: list(tis) for task_id, tis in itertools.groupby(query, key=lambda ti: ti.task_id)} - - sort_order = conf.get("webserver", "grid_view_sorting_order", fallback="topological") - if sort_order == "topological": - - def sort_children_fn(task_group): - return task_group.topological_sort() - - elif sort_order == "hierarchical_alphabetical": - - def sort_children_fn(task_group): - return task_group.hierarchical_alphabetical_sort() + grouped_tis: dict[str, list[TaskInstance]] = collections.defaultdict( + list, + ((task_id, list(tis)) for task_id, tis in itertools.groupby(query, key=lambda ti: ti.task_id)), + ) - else: + @cache + def get_task_group_children_getter() -> operator.methodcaller: + sort_order = conf.get("webserver", "grid_view_sorting_order", fallback="topological") + if sort_order == "topological": + return operator.methodcaller("topological_sort") + if sort_order == "hierarchical_alphabetical": + return operator.methodcaller("hierarchical_alphabetical_sort") raise AirflowConfigException(f"Unsupported grid_view_sorting_order: {sort_order}") - def task_group_to_grid(item, grouped_tis, *, is_parent_mapped: bool): + def task_group_to_grid(item: Operator | TaskGroup) -> dict[str, Any]: if not isinstance(item, TaskGroup): - def _get_summary(task_instance): - return { - "task_id": task_instance.task_id, - "run_id": task_instance.run_id, - "state": task_instance.state, - "queued_dttm": task_instance.queued_dttm, - "start_date": task_instance.start_date, - "end_date": task_instance.end_date, - "try_number": wwwutils.get_try_count(task_instance._try_number, task_instance.state), - "note": task_instance.note, - } - - def _mapped_summary(ti_summaries): - run_id = None - record = None + def _mapped_summary(ti_summaries: list[TaskInstance]) -> Iterator[dict[str, Any]]: + run_id = "" + record: dict[str, Any] = {} def set_overall_state(record): for state in wwwutils.priority: if state in record["mapped_states"]: record["state"] = state break - if None in record["mapped_states"]: - # When turning the dict into JSON we can't have None as a key, - # so use the string that the UI does. + # When turning the dict into JSON we can't have None as a key, + # so use the string that the UI does. + with contextlib.suppress(KeyError): record["mapped_states"]["no_status"] = record["mapped_states"].pop(None) for ti_summary in ti_summaries: @@ -403,10 +391,22 @@ def set_overall_state(record): set_overall_state(record) yield record - if isinstance(item, MappedOperator) or is_parent_mapped: - instances = list(_mapped_summary(grouped_tis.get(item.task_id, []))) + if item_is_mapped := needs_expansion(item): + instances = list(_mapped_summary(grouped_tis[item.task_id])) else: - instances = list(map(_get_summary, grouped_tis.get(item.task_id, []))) + instances = [ + { + "task_id": task_instance.task_id, + "run_id": task_instance.run_id, + "state": task_instance.state, + "queued_dttm": task_instance.queued_dttm, + "start_date": task_instance.start_date, + "end_date": task_instance.end_date, + "try_number": wwwutils.get_try_count(task_instance._try_number, task_instance.state), + "note": task_instance.note, + } + for task_instance in grouped_tis[item.task_id] + ] setup_teardown_type = {} if item.is_setup is True: @@ -419,7 +419,7 @@ def set_overall_state(record): "instances": instances, "label": item.label, "extra_links": item.extra_links, - "is_mapped": isinstance(item, MappedOperator) or is_parent_mapped, + "is_mapped": item_is_mapped, "has_outlet_datasets": any(isinstance(i, Dataset) for i in (item.outlets or [])), "operator": item.operator_name, "trigger_rule": item.trigger_rule, @@ -428,12 +428,7 @@ def set_overall_state(record): # Task Group task_group = item - group_is_mapped = next(task_group.iter_mapped_task_groups(), None) is not None - - children = [ - task_group_to_grid(child, grouped_tis, is_parent_mapped=group_is_mapped) - for child in sort_children_fn(task_group) - ] + children = [task_group_to_grid(child) for child in get_task_group_children_getter()(item)] def get_summary(dag_run: DagRun): child_instances = [ @@ -532,16 +527,14 @@ def get_mapped_group_summary(run_id: str, mapped_instances: Mapping[int, list[Ta "instances": [], } - if group_is_mapped: - mapped_group_summaries = get_mapped_group_summaries() - + if next(task_group.iter_mapped_task_groups(), None) is not None: return { "id": task_group.group_id, "label": task_group.label, "children": children, "tooltip": task_group.tooltip, - "instances": mapped_group_summaries, - "is_mapped": group_is_mapped, + "instances": get_mapped_group_summaries(), + "is_mapped": True, } group_summaries = [get_summary(dr) for dr in dag_runs] @@ -554,7 +547,7 @@ def get_mapped_group_summary(run_id: str, mapped_instances: Mapping[int, list[Ta "instances": group_summaries, } - return task_group_to_grid(dag.task_group, grouped_tis, is_parent_mapped=False) + return task_group_to_grid(dag.task_group) def get_key_paths(input_dict): @@ -3375,7 +3368,7 @@ def extra_links(self, *, session: Session = NEW_SESSION): if not dag or task_id not in dag.task_ids: return {"url": None, "error": f"can't find dag {dag} or task_id {task_id}"}, 404 - task: AbstractOperator = dag.get_task(task_id) + task = dag.get_task(task_id) link_name = request.args.get("link_name") if link_name is None: return {"url": None, "error": "Link name not passed"}, 400