Add default for task on TaskInstance / fix attrs on TaskInstancePydantic (apache#37854)

This was motivated by the need to fix serialization of TaskInstance to TaskInstancePydantic. In some cases the data type was wrong; in other cases the attr was missing. The case of `task` was more complicated. Historically the value is type-declared but not set at init, so it's not always there, and the code base had accumulated a lot of `hasattr` checks as a result. The problem was that when serializing to TaskInstancePydantic, if the value was not set, serialization would fail. So we had to make it optional on TaskInstance and provide a default of None.
dstandish authored Mar 19, 2024
1 parent d3ef673 commit c6bc052
Showing 19 changed files with 172 additions and 28 deletions.
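The heart of the change is easiest to see with stand-in classes. The sketch below is a minimal reconstruction of the failure mode, assuming pydantic v2 (`model_validate` with `from_attributes`); `OrmTI`, `PydTIBefore`, and `PydTIAfter` are hypothetical stand-ins, not the real `TaskInstance` / `TaskInstancePydantic` classes.

```python
# Stand-in classes illustrating the serialization failure and the fix.
from typing import Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class OrmTI:
    task: str  # annotation only -- the attribute does not exist until assigned

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id


class PydTIBefore(BaseModel):  # mirrors the old declaration: `task` is required
    model_config = ConfigDict(from_attributes=True)
    task_id: str
    task: str


class PydTIAfter(BaseModel):  # mirrors the fix: optional with a default of None
    model_config = ConfigDict(from_attributes=True)
    task_id: str
    task: Optional[str] = None


ti = OrmTI("extract")
try:
    PydTIBefore.model_validate(ti)  # `ti.task` was never set, so the field is missing
except ValidationError as err:
    print(err.error_count(), "validation error")  # 1 validation error

print(PydTIAfter.model_validate(ti))  # task_id='extract' task=None
```

Once the attribute also carries a class-level default of `None` on the ORM side, it is always readable, which is what lets the diff below replace scattered `hasattr` checks with type narrowing.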
8 changes: 7 additions & 1 deletion airflow/example_dags/plugins/event_listener.py
@@ -48,6 +48,9 @@ def on_task_instance_running(previous_state: TaskInstanceState, task_instance: T

task = task_instance.task

if TYPE_CHECKING:
assert task

dag = task.dag
dag_name = None
if dag:
@@ -103,7 +106,10 @@ def on_task_instance_failed(previous_state: TaskInstanceState, task_instance: Ta

task = task_instance.task

dag = task_instance.task.dag
if TYPE_CHECKING:
assert task

dag = task.dag

print(f"Task start:{start_date} end:{end_date} duration:{duration}")
print(f"Task:{task} dag:{dag} dagrun:{dagrun}")
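Most hunks in this commit repeat one idiom: once `task` is typed `Operator | None`, every dereference needs narrowing, and an `assert` guarded by `TYPE_CHECKING` gives mypy that narrowing at zero runtime cost, because the block never executes (`TYPE_CHECKING` is `False` at runtime). A minimal sketch with hypothetical stand-in classes:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Optional


class Task:
    def render(self) -> str:
        return "rendered"


class TI:
    task: Optional[Task] = None  # optional with a default, per this commit


def handle(ti: TI) -> str:
    if TYPE_CHECKING:
        # Seen only by the type checker: narrows Optional[Task] to Task.
        # At runtime TYPE_CHECKING is False, so no check is performed --
        # call sites like the listeners above only run with a task attached.
        assert ti.task
    return ti.task.render()


ti = TI()
ti.task = Task()
print(handle(ti))  # rendered
```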
3 changes: 3 additions & 0 deletions airflow/executors/base_executor.py
@@ -165,6 +165,9 @@ def queue_task_instance(
cfg_path: str | None = None,
) -> None:
"""Queues task instance."""
if TYPE_CHECKING:
assert task_instance.task

pool = pool or task_instance.pool

command_list_to_run = task_instance.command_as_list(
3 changes: 3 additions & 0 deletions airflow/executors/debug_executor.py
@@ -107,6 +107,9 @@ def queue_task_instance(
cfg_path: str | None = None,
) -> None:
"""Queues task instance with empty command because we do not need it."""
if TYPE_CHECKING:
assert task_instance.task

self.queue_command(
task_instance,
[str(task_instance)], # Just for better logging, it's not used anywhere
2 changes: 2 additions & 0 deletions airflow/jobs/local_task_job_runner.py
@@ -257,6 +257,8 @@ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:

self.task_instance.refresh_from_db()
ti = self.task_instance
if TYPE_CHECKING:
assert ti.task

if ti.state == TaskInstanceState.RUNNING:
fqdn = get_hostname()
5 changes: 4 additions & 1 deletion airflow/lineage/__init__.py
@@ -54,7 +54,10 @@ def get_backend() -> LineageBackend | None:


def _render_object(obj: Any, context: Context) -> dict:
return context["ti"].task.render_template(obj, context)
ti = context["ti"]
if TYPE_CHECKING:
assert ti.task
return ti.task.render_template(obj, context)


T = TypeVar("T", bound=Callable)
13 changes: 10 additions & 3 deletions airflow/models/dagrun.py
@@ -780,9 +780,9 @@ def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
def should_schedule(self) -> bool:
return (
bool(self.tis)
and all(not t.task.depends_on_past for t in self.tis)
and all(t.task.max_active_tis_per_dag is None for t in self.tis)
and all(t.task.max_active_tis_per_dagrun is None for t in self.tis)
and all(not t.task.depends_on_past for t in self.tis) # type: ignore[union-attr]
and all(t.task.max_active_tis_per_dag is None for t in self.tis) # type: ignore[union-attr]
and all(t.task.max_active_tis_per_dagrun is None for t in self.tis) # type: ignore[union-attr]
and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
)
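One place the assert idiom cannot help is `should_schedule` just above: the Optional attribute is dereferenced on the loop variable of a generator expression, and there is no statement position where a narrowing `assert` could go, hence the `# type: ignore[union-attr]` comments instead. A sketch of the same shape, with stand-in classes:

```python
from __future__ import annotations

from typing import Optional


class Task:
    depends_on_past: bool = False


class TI:
    task: Optional[Task] = None


def should_schedule(tis: list[TI]) -> bool:
    # `t` is local to the generator expression, so `assert t.task` has no
    # place to live; the ignore mirrors the real code's assumption that
    # every schedulable TI already has its task attached.
    return bool(tis) and all(
        not t.task.depends_on_past  # type: ignore[union-attr]
        for t in tis
    )


print(should_schedule([]))  # False
```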

@@ -1020,6 +1020,9 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
If the ti does not need expansion, either because the task is not
mapped, or has already been expanded, *None* is returned.
"""
if TYPE_CHECKING:
assert ti.task

if ti.map_index >= 0: # Already expanded, we're good.
return None

@@ -1043,6 +1046,8 @@ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
# Set of task ids for which was already done _revise_map_indexes_if_mapped
revised_map_index_task_ids = set()
for schedulable in itertools.chain(schedulable_tis, additional_tis):
if TYPE_CHECKING:
assert schedulable.task
old_state = schedulable.state
if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
old_states[schedulable.key] = old_state
@@ -1525,6 +1530,8 @@ def schedule_tis(
dummy_ti_ids = []
schedulable_ti_ids = []
for ti in schedulable_tis:
if TYPE_CHECKING:
assert ti.task
if (
ti.task.inherits_from_empty_operator
and not ti.task.on_execute_callback
4 changes: 4 additions & 0 deletions airflow/models/renderedtifields.py
@@ -109,6 +109,10 @@ def __init__(self, ti: TaskInstance, render_templates=True):
self.ti = ti
if render_templates:
ti.render_templates()

if TYPE_CHECKING:
assert ti.task

self.task = ti.task
if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
# we can safely import it here from provider. In Airflow 2.7.0+ you need to have new version
3 changes: 2 additions & 1 deletion airflow/models/skipmixin.py
@@ -199,11 +199,12 @@ def skip_all_except(
dag_run = ti.get_dagrun()
if TYPE_CHECKING:
assert isinstance(dag_run, DagRun)
assert ti.task

# TODO(potiuk): Handle TaskInstancePydantic case differently - we need to figure out the way to
# pass task that has been set in LocalTaskJob but in the way that TaskInstancePydantic definition
# does not attempt to serialize the field from/to ORM
task = ti.task # type: ignore[union-attr]
task = ti.task
dag = task.dag
if TYPE_CHECKING:
assert dag
77 changes: 72 additions & 5 deletions airflow/models/taskinstance.py
@@ -206,6 +206,7 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance | TaskInstancePydantic,
raise ValueError("``task_instance`` must have ``dag_run`` set")
tis = task_instance.dag_run.get_task_instances(session=session)
if TYPE_CHECKING:
assert task_instance.task
assert isinstance(task_instance.task.dag, DAG)

for ti in tis:
@@ -268,6 +269,8 @@ def clear_task_instances(
if ti_dag and ti_dag.has_task(task_id):
task = ti_dag.get_task(task_id)
ti.refresh_from_task(task)
if TYPE_CHECKING:
assert ti.task
task_retries = task.retries
ti.max_tries = ti.try_number + task_retries - 1
else:
@@ -396,6 +399,9 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
"""
task_to_execute = task_instance.task

if TYPE_CHECKING:
assert task_to_execute

if isinstance(task_to_execute, MappedOperator):
raise AirflowException("MappedOperator cannot be executed.")

@@ -601,6 +607,7 @@ def _get_template_context(

task = task_instance.task
if TYPE_CHECKING:
assert task
assert task.dag
dag: DAG = task.dag

@@ -814,6 +821,9 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic)
# Couldn't load the task, don't know number of retries, guess:
return task_instance.try_number <= task_instance.max_tries

if TYPE_CHECKING:
assert task_instance.task

return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries


@@ -977,6 +987,9 @@ def _get_previous_dagrun(
:meta private:
"""
if TYPE_CHECKING:
assert task_instance.task

dag = task_instance.task.dag
if dag is None:
return None
@@ -1104,6 +1117,9 @@ def _get_email_subject_content(
html_content_err = jinja_env.from_string(default_html_content_err).render(**default_context)

else:
if TYPE_CHECKING:
assert task_instance.task

# Use the DAG's get_template_env() to set force_sandboxed. Don't add
# the flag to the function on task object -- that function can be
# overridden, and adding a flag breaks backward compatibility.
@@ -1338,9 +1354,11 @@ class TaskInstance(Base, LoggingMixin):
cascade="all, delete, delete-orphan",
)
note = association_proxy("task_instance_note", "content", creator=_creator_note)
task: Operator # Not always set...
task: Operator | None = None
test_mode: bool = False
is_trigger_log_context: bool = False
run_as_user: str | None = None
raw: bool | None = None
"""Indicate to FileTaskHandler that logging context should be set up for trigger logging.
:meta private:
@@ -1360,6 +1378,9 @@ def __init__(
self.task_id = task.task_id
self.map_index = map_index
self.refresh_from_task(task)
if TYPE_CHECKING:
assert self.task

# init_on_load will config the log
self.init_on_load()

@@ -1524,7 +1545,9 @@ def _command_as_list(
) -> list[str]:
dag: DAG | DagModel | DagModelPydantic | None
# Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
if hasattr(ti, "task") and hasattr(ti.task, "dag") and ti.task.dag is not None:
if hasattr(ti, "task") and getattr(ti.task, "dag", None) is not None:
if TYPE_CHECKING:
assert ti.task
dag = ti.task.dag
else:
dag = ti.dag_model
@@ -1858,6 +1881,8 @@ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
:param session: SQLAlchemy ORM Session
"""
task = self.task
if TYPE_CHECKING:
assert task

if not task.downstream_task_ids:
return True
@@ -2014,6 +2039,9 @@ def are_dependencies_met(
@provide_session
def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
"""Get failed Dependencies."""
if TYPE_CHECKING:
assert self.task

dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
for dep_status in dep.get_dep_statuses(self, session, dep_context):
@@ -2090,14 +2118,19 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
"""
info = inspect(self)
if info.attrs.dag_run.loaded_value is not NO_VALUE:
if hasattr(self, "task"):
if getattr(self, "task", None) is not None:
if TYPE_CHECKING:
assert self.task
self.dag_run.dag = self.task.dag
return self.dag_run

from airflow.models.dagrun import DagRun # Avoid circular import

dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
if hasattr(self, "task"):
if getattr(self, "task", None) is not None:
if TYPE_CHECKING:
assert self.task

dr.dag = self.task.dag
# Record it in the instance for next time. This means that `self.execution_date` will work correctly
set_committed_value(self, "dag_run", dr)
@@ -2145,6 +2178,9 @@ def _check_and_change_state_before_execution(
:param session: SQLAlchemy ORM Session
:return: whether the state was changed to running or not
"""
if TYPE_CHECKING:
assert task_instance.task

if isinstance(task_instance, TaskInstance):
ti: TaskInstance = task_instance
else: # isinstance(task_instance,TaskInstancePydantic)
@@ -2357,9 +2393,13 @@ def _run_raw_task(
:param pool: specifies the pool to use to run the task instance
:param session: SQLAlchemy ORM Session
"""
if TYPE_CHECKING:
assert self.task

self.test_mode = test_mode
self.refresh_from_task(self.task, pool_override=pool)
self.refresh_from_db(session=session)

self.job_id = job_id
self.hostname = get_hostname()
self.pid = os.getpid()
@@ -2488,6 +2528,9 @@ def _run_raw_task(
return None

def _register_dataset_changes(self, *, session: Session) -> None:
if TYPE_CHECKING:
assert self.task

for obj in self.task.outlets or []:
self.log.debug("outlet obj %s", obj)
# Lineage can have other types of objects besides datasets
@@ -2502,6 +2545,9 @@ def _execute_task_with_callbacks(self, context: Context, test_mode: bool = False
"""Prepare Task for Execution."""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

if TYPE_CHECKING:
assert self.task

parent_pid = os.getpid()

def signal_handler(signum, frame):
@@ -2603,6 +2649,9 @@ def defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""
from airflow.models.trigger import Trigger

if TYPE_CHECKING:
assert self.task

# First, make the trigger entry
trigger_row = Trigger.from_object(defer.trigger)
session.add(trigger_row)
@@ -2689,6 +2738,9 @@ def run(

def dry_run(self) -> None:
"""Only Renders Templates for the TI."""
if TYPE_CHECKING:
assert self.task

self.task = self.task.prepare_for_execution()
self.render_templates()
if TYPE_CHECKING:
@@ -2711,6 +2763,9 @@ def _handle_reschedule(

self.refresh_from_db(session)

if TYPE_CHECKING:
assert self.task

self.end_date = timezone.utcnow()
self.set_duration()

@@ -2832,6 +2887,8 @@ def fetch_handle_failure_context(
task: BaseOperator | None = None
try:
if getattr(ti, "task", None) and context:
if TYPE_CHECKING:
assert ti.task
task = ti.task.unmap((context, session))
except Exception:
cls.logger().error("Unable to unmap task to determine if we need to send an alert email")
@@ -2923,6 +2980,9 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields

if TYPE_CHECKING:
assert self.task

rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
if rendered_task_instance_fields:
self.task = self.task.unmap(None)
@@ -2965,6 +3025,9 @@ def render_templates(
context = self.get_template_context()
original_task = self.task

if TYPE_CHECKING:
assert original_task

# If self.task is mapped, this call replaces self.task to point to the
# unmapped BaseOperator created by this function! This is because the
# MappedOperator is useless for template rendering, and we need to be
@@ -3336,6 +3399,7 @@ def _schedule_downstream_tasks(

task = ti.task
if TYPE_CHECKING:
assert task
assert task.dag

# Get a partial DAG with just the specific tasks we want to examine.
Expand Down Expand Up @@ -3367,7 +3431,7 @@ def _schedule_downstream_tasks(
)
]
for schedulable_ti in schedulable_tis:
if not hasattr(schedulable_ti, "task"):
if getattr(schedulable_ti, "task", None) is None:
schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

num = dag_run.schedule_tis(schedulable_tis, session=session, max_tis_per_query=max_tis_per_query)
Expand Down Expand Up @@ -3442,6 +3506,9 @@ def tg2(inp):
:return: Specific map index or map indexes to pull, or ``None`` if we
want to "whole" return value (i.e. no mapped task groups involved).
"""
if TYPE_CHECKING:
assert self.task

# This value should never be None since we already know the current task
# is in a mapped task group, and should have been expanded, despite that,
# we need to check that it is not None to satisfy Mypy.
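A related detail in several of the hunks above (`_command_as_list`, `get_dagrun`, `_schedule_downstream_tasks`): `hasattr(ti, "task")` is swapped for `getattr(ti, "task", None) is not None`. With the new class-level default of `None`, the attribute always exists, so `hasattr` would always return `True` and no longer distinguish an attached task from a missing one. A two-line illustration with a stand-in class:

```python
class TI:
    task = None  # class-level default, as introduced by this commit


ti = TI()
print(hasattr(ti, "task"))                    # True -- always, so no longer a signal
print(getattr(ti, "task", None) is not None)  # False -- still detects "no task attached"
```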