Skip to content

Commit

Permalink
View util refactoring on mapped stuff use cases (apache#34638)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Oct 5, 2023
1 parent 1047ff8 commit 6498f67
Showing 1 changed file with 45 additions and 52 deletions.
97 changes: 45 additions & 52 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from __future__ import annotations

import collections.abc
import contextlib
import copy
import datetime
import itertools
import json
import logging
import math
import operator
import sys
import traceback
import warnings
Expand Down Expand Up @@ -82,6 +84,7 @@
set_state,
)
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.compat.functools import cache
from airflow.configuration import AIRFLOW_CONFIG, conf
from airflow.datasets import Dataset
from airflow.exceptions import (
Expand All @@ -100,7 +103,7 @@
from airflow.models.dag import get_dataset_triggered_next_run_info
from airflow.models.dagrun import RUN_ID_REGEX, DagRun, DagRunType
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import needs_expansion
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance, TaskInstanceNote
from airflow.providers_manager import ProvidersManager
Expand Down Expand Up @@ -140,7 +143,6 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session

from airflow.models.abstractoperator import AbstractOperator
from airflow.models.dag import DAG
from airflow.models.operator import Operator

Expand Down Expand Up @@ -294,7 +296,7 @@ def node_dict(node_id, label, node_class):
}


def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session):
def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session) -> dict[str, Any]:
"""
Create a nested dict representation of the DAG's TaskGroup and its children.
Expand All @@ -321,49 +323,35 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session):
.order_by(TaskInstance.task_id, TaskInstance.run_id)
)

grouped_tis = {task_id: list(tis) for task_id, tis in itertools.groupby(query, key=lambda ti: ti.task_id)}

sort_order = conf.get("webserver", "grid_view_sorting_order", fallback="topological")
if sort_order == "topological":

def sort_children_fn(task_group):
return task_group.topological_sort()

elif sort_order == "hierarchical_alphabetical":

def sort_children_fn(task_group):
return task_group.hierarchical_alphabetical_sort()
grouped_tis: dict[str, list[TaskInstance]] = collections.defaultdict(
list,
((task_id, list(tis)) for task_id, tis in itertools.groupby(query, key=lambda ti: ti.task_id)),
)

else:
@cache
def get_task_group_children_getter() -> operator.methodcaller:
    """Return a cached ``methodcaller`` that lists a TaskGroup's children.

    The webserver config option ``grid_view_sorting_order`` selects which
    TaskGroup sorting method the grid view uses.

    :raises AirflowConfigException: if the configured value is not one of the
        supported sort orders.
    """
    # Map each supported config value to the TaskGroup method it selects.
    sorters = {
        "topological": "topological_sort",
        "hierarchical_alphabetical": "hierarchical_alphabetical_sort",
    }
    sort_order = conf.get("webserver", "grid_view_sorting_order", fallback="topological")
    if sort_order in sorters:
        return operator.methodcaller(sorters[sort_order])
    raise AirflowConfigException(f"Unsupported grid_view_sorting_order: {sort_order}")

def task_group_to_grid(item, grouped_tis, *, is_parent_mapped: bool):
def task_group_to_grid(item: Operator | TaskGroup) -> dict[str, Any]:
if not isinstance(item, TaskGroup):

def _get_summary(task_instance):
return {
"task_id": task_instance.task_id,
"run_id": task_instance.run_id,
"state": task_instance.state,
"queued_dttm": task_instance.queued_dttm,
"start_date": task_instance.start_date,
"end_date": task_instance.end_date,
"try_number": wwwutils.get_try_count(task_instance._try_number, task_instance.state),
"note": task_instance.note,
}

def _mapped_summary(ti_summaries):
run_id = None
record = None
def _mapped_summary(ti_summaries: list[TaskInstance]) -> Iterator[dict[str, Any]]:
run_id = ""
record: dict[str, Any] = {}

def set_overall_state(record):
    """Assign the record's overall state and make its mapped-state keys JSON-safe.

    Mutates *record* in place: sets ``record["state"]`` to the highest-priority
    state present in ``record["mapped_states"]``, and renames a ``None`` key to
    the string ``"no_status"``.
    """
    mapped_states = record["mapped_states"]
    for candidate in wwwutils.priority:
        if candidate in mapped_states:
            record["state"] = candidate
            break
    # When turning the dict into JSON we can't have None as a key,
    # so use the string that the UI does.
    with contextlib.suppress(KeyError):
        mapped_states["no_status"] = mapped_states.pop(None)

for ti_summary in ti_summaries:
Expand Down Expand Up @@ -403,10 +391,22 @@ def set_overall_state(record):
set_overall_state(record)
yield record

if isinstance(item, MappedOperator) or is_parent_mapped:
instances = list(_mapped_summary(grouped_tis.get(item.task_id, [])))
if item_is_mapped := needs_expansion(item):
instances = list(_mapped_summary(grouped_tis[item.task_id]))
else:
instances = list(map(_get_summary, grouped_tis.get(item.task_id, [])))
instances = [
{
"task_id": task_instance.task_id,
"run_id": task_instance.run_id,
"state": task_instance.state,
"queued_dttm": task_instance.queued_dttm,
"start_date": task_instance.start_date,
"end_date": task_instance.end_date,
"try_number": wwwutils.get_try_count(task_instance._try_number, task_instance.state),
"note": task_instance.note,
}
for task_instance in grouped_tis[item.task_id]
]

setup_teardown_type = {}
if item.is_setup is True:
Expand All @@ -419,7 +419,7 @@ def set_overall_state(record):
"instances": instances,
"label": item.label,
"extra_links": item.extra_links,
"is_mapped": isinstance(item, MappedOperator) or is_parent_mapped,
"is_mapped": item_is_mapped,
"has_outlet_datasets": any(isinstance(i, Dataset) for i in (item.outlets or [])),
"operator": item.operator_name,
"trigger_rule": item.trigger_rule,
Expand All @@ -428,12 +428,7 @@ def set_overall_state(record):

# Task Group
task_group = item
group_is_mapped = next(task_group.iter_mapped_task_groups(), None) is not None

children = [
task_group_to_grid(child, grouped_tis, is_parent_mapped=group_is_mapped)
for child in sort_children_fn(task_group)
]
children = [task_group_to_grid(child) for child in get_task_group_children_getter()(item)]

def get_summary(dag_run: DagRun):
child_instances = [
Expand Down Expand Up @@ -532,16 +527,14 @@ def get_mapped_group_summary(run_id: str, mapped_instances: Mapping[int, list[Ta
"instances": [],
}

if group_is_mapped:
mapped_group_summaries = get_mapped_group_summaries()

if next(task_group.iter_mapped_task_groups(), None) is not None:
return {
"id": task_group.group_id,
"label": task_group.label,
"children": children,
"tooltip": task_group.tooltip,
"instances": mapped_group_summaries,
"is_mapped": group_is_mapped,
"instances": get_mapped_group_summaries(),
"is_mapped": True,
}

group_summaries = [get_summary(dr) for dr in dag_runs]
Expand All @@ -554,7 +547,7 @@ def get_mapped_group_summary(run_id: str, mapped_instances: Mapping[int, list[Ta
"instances": group_summaries,
}

return task_group_to_grid(dag.task_group, grouped_tis, is_parent_mapped=False)
return task_group_to_grid(dag.task_group)


def get_key_paths(input_dict):
Expand Down Expand Up @@ -3375,7 +3368,7 @@ def extra_links(self, *, session: Session = NEW_SESSION):
if not dag or task_id not in dag.task_ids:
return {"url": None, "error": f"can't find dag {dag} or task_id {task_id}"}, 404

task: AbstractOperator = dag.get_task(task_id)
task = dag.get_task(task_id)
link_name = request.args.get("link_name")
if link_name is None:
return {"url": None, "error": "Link name not passed"}, 400
Expand Down

0 comments on commit 6498f67

Please sign in to comment.