Skip to content

Commit

Permalink
fix: replace dagTree with downstream_task_ids (apache#41587)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda authored Aug 21, 2024
1 parent 83a6cb4 commit 86e12a9
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 184 deletions.
44 changes: 4 additions & 40 deletions airflow/providers/openlineage/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from airflow import __version__ as AIRFLOW_VERSION
from airflow.datasets import Dataset
from airflow.exceptions import AirflowProviderDeprecationWarning # TODO: move this maybe to Airflow's logic?
from airflow.models import DAG, BaseOperator, MappedOperator, Operator
from airflow.models import DAG, BaseOperator, MappedOperator
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.plugins.facets import (
AirflowDagRunFacet,
Expand Down Expand Up @@ -423,7 +423,7 @@ def get_airflow_job_facet(dag_run: DagRun) -> dict[str, AirflowJobFacet]:
return {}
return {
"airflow": AirflowJobFacet(
taskTree=_get_parsed_dag_tree(dag_run.dag),
taskTree={}, # caused OOM errors, to be removed, see #41587
taskGroups=_get_task_groups_details(dag_run.dag),
tasks=_get_tasks_details(dag_run.dag),
)
Expand All @@ -439,43 +439,6 @@ def get_airflow_state_run_facet(dag_run: DagRun) -> dict[str, AirflowStateRunFac
}


def _get_parsed_dag_tree(dag: DAG) -> dict:
"""
Get DAG's tasks hierarchy representation.
While the task dependencies are defined as following:
task >> [task_2, task_4] >> task_7
task_3 >> task_5
task_6 # has no dependencies, it's a root and a leaf
The result of this function will look like:
{
"task": {
"task_2": {
"task_7": {}
},
"task_4": {
"task_7": {}
}
},
"task_3": {
"task_5": {}
},
"task_6": {}
}
"""

def get_downstream(task: Operator, current_dict: dict):
current_dict[task.task_id] = {}
for tmp_task in sorted(task.downstream_list, key=lambda x: x.task_id):
get_downstream(tmp_task, current_dict[task.task_id])

task_dict: dict = {}
for t in sorted(dag.roots, key=lambda x: x.task_id):
get_downstream(t, task_dict)
return task_dict


def _get_tasks_details(dag: DAG) -> dict:
tasks = {
single_task.task_id: {
Expand All @@ -487,8 +450,9 @@ def _get_tasks_details(dag: DAG) -> dict:
"ui_label": single_task.label,
"is_setup": single_task.is_setup,
"is_teardown": single_task.is_teardown,
"downstream_task_ids": sorted(single_task.downstream_task_ids),
}
for single_task in dag.tasks
for single_task in sorted(dag.tasks, key=lambda x: x.task_id)
}

return tasks
Expand Down
Loading

0 comments on commit 86e12a9

Please sign in to comment.