diff --git a/airflow/api_internal/endpoints/rpc_api_endpoint.py b/airflow/api_internal/endpoints/rpc_api_endpoint.py index 1712e43d7f068..1820d63194106 100644 --- a/airflow/api_internal/endpoints/rpc_api_endpoint.py +++ b/airflow/api_internal/endpoints/rpc_api_endpoint.py @@ -26,7 +26,6 @@ from flask import Response from airflow.jobs.job import Job, most_recent_job -from airflow.models.taskinstance import _get_template_context, _update_rtif from airflow.sensors.base import _orig_start_date from airflow.serialization.serialized_objects import BaseSerialization from airflow.utils.session import create_session @@ -42,22 +41,37 @@ def _initialize_map() -> dict[str, Callable]: from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor + from airflow.datasets.manager import DatasetManager from airflow.models import Trigger, Variable, XCom from airflow.models.dag import DAG, DagModel from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning from airflow.models.serialized_dag import SerializedDagModel - from airflow.models.taskinstance import TaskInstance + from airflow.models.taskinstance import ( + TaskInstance, + _add_log, + _defer_task, + _get_template_context, + _handle_failure, + _handle_reschedule, + _update_rtif, + _xcom_pull, + ) from airflow.secrets.metastore import MetastoreBackend from airflow.utils.cli_action_loggers import _default_action_log_internal from airflow.utils.log.file_task_handler import FileTaskHandler functions: list[Callable] = [ _default_action_log_internal, + _defer_task, _get_template_context, _get_ti_db_access, _update_rtif, _orig_start_date, + _handle_failure, + _handle_reschedule, + _add_log, + _xcom_pull, DagFileProcessor.update_import_errors, DagFileProcessor.manage_slas, DagFileProcessorManager.deactivate_stale_dags, @@ -66,6 +80,7 @@ def _initialize_map() -> dict[str, Callable]: DagModel.get_current, DagFileProcessorManager.clear_nonexistent_import_errors, DagWarning.purge_inactive_dag_warnings, + DatasetManager.register_dataset_change, FileTaskHandler._render_filename_db_access, Job._add_to_db, Job._fetch_from_db, @@ -79,6 +94,7 @@ def _initialize_map() -> dict[str, Callable]: XCom.get_one, XCom.get_many, XCom.clear, + XCom.set, Variable.set, Variable.update, Variable.delete, @@ -94,7 +110,6 @@ def _initialize_map() -> dict[str, Callable]: TaskInstance.get_task_instance, TaskInstance._get_dagrun, TaskInstance._set_state, - TaskInstance.fetch_handle_failure_context, TaskInstance.save_to_db, TaskInstance._schedule_downstream_tasks, TaskInstance._clear_xcom_data, diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index abc29b0ddf671..ac9b211c21798 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -673,7 +673,7 @@ def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> N else: ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True) except TaskDeferred as defer: - ti.defer_task(defer=defer, session=session) + ti.defer_task(exception=defer, session=session) log.info("[TASK TEST] running trigger in line") event = _run_inline_trigger(defer.trigger) diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py index 0a6aad50a0f89..d3e7a8cf84bf2 100644 --- a/airflow/datasets/manager.py +++ b/airflow/datasets/manager.py @@ -22,12 +22,14 @@ from sqlalchemy import exc, select from sqlalchemy.orm import joinedload +from airflow.api_internal.internal_api_call import internal_api_call from airflow.configuration import conf from airflow.datasets import Dataset from airflow.listeners.listener import get_listener_manager from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -56,13 +58,16 @@ def create_datasets(self, dataset_models: list[DatasetModel], session: Session) for dataset_model in dataset_models: self.notify_dataset_created(dataset=Dataset(uri=dataset_model.uri, extra=dataset_model.extra)) + @classmethod + @internal_api_call + @provide_session def register_dataset_change( - self, + cls, *, task_instance: TaskInstance | None = None, dataset: Dataset, extra=None, - session: Session, + session: Session = NEW_SESSION, **kwargs, ) -> DatasetEvent | None: """ @@ -71,13 +76,14 @@ def register_dataset_change( For local datasets, look them up, record the dataset event, queue dagruns, and broadcast the dataset event """ + # todo: add test so that all usages of internal_api_call are added to rpc endpoint dataset_model = session.scalar( select(DatasetModel) .where(DatasetModel.uri == dataset.uri) .options(joinedload(DatasetModel.consuming_dags).joinedload(DagScheduleDatasetReference.dag)) ) if not dataset_model: - self.log.warning("DatasetModel %s not found", dataset) + cls.logger().warning("DatasetModel %s not found", dataset) return None event_kwargs = { @@ -97,10 +103,10 @@ def register_dataset_change( session.add(dataset_event) session.flush() - self.notify_dataset_changed(dataset=dataset) + cls.notify_dataset_changed(dataset=dataset) Stats.incr("dataset.updates") - self._queue_dagruns(dataset_model, session) + cls._queue_dagruns(dataset_model, session) session.flush() return dataset_event @@ -108,11 +114,13 @@ def notify_dataset_created(self, dataset: Dataset): """Run applicable notification actions when a dataset is created.""" get_listener_manager().hook.on_dataset_created(dataset=dataset) - def notify_dataset_changed(self, dataset: Dataset): + @classmethod + def notify_dataset_changed(cls, dataset: Dataset): """Run applicable notification actions when a dataset is changed.""" get_listener_manager().hook.on_dataset_changed(dataset=dataset) - def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + @classmethod + def _queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None: # Possible race condition: if multiple dags or multiple (usually # mapped) tasks update the same dataset, this can fail with a unique # constraint violation. @@ -123,10 +131,11 @@ def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: # where `ti.state` is changed. if session.bind.dialect.name == "postgresql": - return self._postgres_queue_dagruns(dataset, session) - return self._slow_path_queue_dagruns(dataset, session) + return cls._postgres_queue_dagruns(dataset, session) + return cls._slow_path_queue_dagruns(dataset, session) - def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + @classmethod + def _slow_path_queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None: def _queue_dagrun_if_needed(dag: DagModel) -> str | None: if not dag.is_active or dag.is_paused: return None @@ -137,14 +146,15 @@ def _queue_dagrun_if_needed(dag: DagModel) -> str | None: with session.begin_nested(): session.merge(item) except exc.IntegrityError: - self.log.debug("Skipping record %s", item, exc_info=True) + cls.logger().debug("Skipping record %s", item, exc_info=True) return dag.dag_id queued_results = (_queue_dagrun_if_needed(ref.dag) for ref in dataset.consuming_dags) if queued_dag_ids := [r for r in queued_results if r is not None]: - self.log.debug("consuming dag ids %s", queued_dag_ids) + cls.logger().debug("consuming dag ids %s", queued_dag_ids) - def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None: + @classmethod + def _postgres_queue_dagruns(cls, dataset: DatasetModel, session: Session) -> None: from sqlalchemy.dialects.postgresql import insert values = [ diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 84076ee70c53e..008608b96d225 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -1548,7 +1548,7 @@ def schedule_tis( if ti.state != TaskInstanceState.UP_FOR_RESCHEDULE: ti.try_number += 1 ti.defer_task( - defer=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method), + exception=TaskDeferred(trigger=ti.task.start_trigger, method_name=ti.task.next_method), session=session, ) else: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index a27579b05e752..1b2485f7025e3 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -187,6 +187,191 @@ class TaskReturnCode(Enum): """When task exits with deferral to trigger.""" +@internal_api_call +@provide_session +def _merge_ti(ti, session: Session = NEW_SESSION): + session.merge(ti) + session.commit() + + +@internal_api_call +@provide_session +def _add_log( + event, + task_instance=None, + owner=None, + owner_display_name=None, + extra=None, + session: Session = NEW_SESSION, + **kwargs, +): + session.add( + Log( + event, + task_instance, + owner, + owner_display_name, + extra, + **kwargs, + ) + ) + + +def _run_raw_task( + ti: TaskInstance | TaskInstancePydantic, + mark_success: bool = False, + test_mode: bool = False, + job_id: str | None = None, + pool: str | None = None, + raise_on_defer: bool = False, + session: Session | None = None, +) -> TaskReturnCode | None: + """ + Run a task, update the state upon completion, and run any appropriate callbacks. + + Immediately runs the task (without checking or changing db state + before execution) and then sets the appropriate final state after + completion and runs any post-execute callbacks. Meant to be called + only after another function changes the state to running. + + :param mark_success: Don't run the task, mark its state as success + :param test_mode: Doesn't record success or failure in the DB + :param pool: specifies the pool to use to run the task instance + :param session: SQLAlchemy ORM Session + """ + if TYPE_CHECKING: + assert ti.task + + ti.test_mode = test_mode + ti.refresh_from_task(ti.task, pool_override=pool) + ti.refresh_from_db(session=session) + + ti.job_id = job_id + ti.hostname = get_hostname() + ti.pid = os.getpid() + if not test_mode: + TaskInstance.save_to_db(ti=ti, session=session) + actual_start_date = timezone.utcnow() + Stats.incr(f"ti.start.{ti.task.dag_id}.{ti.task.task_id}", tags=ti.stats_tags) + # Same metric with tagging + Stats.incr("ti.start", tags=ti.stats_tags) + # Initialize final state counters at zero + for state in State.task_states: + Stats.incr( + f"ti.finish.{ti.task.dag_id}.{ti.task.task_id}.{state}", + count=0, + tags=ti.stats_tags, + ) + # Same metric with tagging + Stats.incr( + "ti.finish", + count=0, + tags={**ti.stats_tags, "state": str(state)}, + ) + with set_current_task_instance_session(session=session): + ti.task = ti.task.prepare_for_execution() + context = ti.get_template_context(ignore_param_exceptions=False, session=session) + + try: + if not mark_success: + TaskInstance._execute_task_with_callbacks( + self=ti, # type: ignore[arg-type] + context=context, + test_mode=test_mode, + session=session, + ) + if not test_mode: + ti.refresh_from_db(lock_for_update=True, session=session) + ti.state = TaskInstanceState.SUCCESS + except TaskDeferred as defer: + # The task has signalled it wants to defer execution based on + # a trigger. + if raise_on_defer: + raise + ti.defer_task(exception=defer, session=session) + ti.log.info( + "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", + ti.dag_id, + ti.task_id, + ti.run_id, + _date_or_empty(task_instance=ti, attr="execution_date"), + _date_or_empty(task_instance=ti, attr="start_date"), + ) + return TaskReturnCode.DEFERRED + except AirflowSkipException as e: + # Recording SKIP + # log only if exception has any arguments to prevent log flooding + if e.args: + ti.log.info(e) + if not test_mode: + ti.refresh_from_db(lock_for_update=True, session=session) + ti.state = TaskInstanceState.SKIPPED + _run_finished_callback(callbacks=ti.task.on_skipped_callback, context=context) + TaskInstance.save_to_db(ti=ti, session=session) + except AirflowRescheduleException as reschedule_exception: + ti._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) + ti.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") + return None + except (AirflowFailException, AirflowSensorTimeout) as e: + # If AirflowFailException is raised, task should not retry. + # If a sensor in reschedule mode reaches timeout, task should not retry. + ti.handle_failure(e, test_mode, context, force_fail=True, session=session) # already saves to db + raise + except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e: + if not test_mode: + ti.refresh_from_db(lock_for_update=True, session=session) + # for case when task is marked as success/failed externally + # or dagrun timed out and task is marked as skipped + # current behavior doesn't hit the callbacks + if ti.state in State.finished: + ti.clear_next_method_args() + TaskInstance.save_to_db(ti=ti, session=session) + return None + else: + ti.handle_failure(e, test_mode, context, session=session) + raise + except SystemExit as e: + # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. + # Therefore, here we must handle only error codes. + msg = f"Task failed due to SystemExit({e.code})" + ti.handle_failure(msg, test_mode, context, session=session) + raise AirflowException(msg) + except BaseException as e: + ti.handle_failure(e, test_mode, context, session=session) + raise + finally: + Stats.incr( + f"ti.finish.{ti.dag_id}.{ti.task_id}.{ti.state}", + tags=ti.stats_tags, + ) + # Same metric with tagging + Stats.incr("ti.finish", tags={**ti.stats_tags, "state": str(ti.state)}) + + # Recording SKIPPED or SUCCESS + ti.clear_next_method_args() + ti.end_date = timezone.utcnow() + _log_state(task_instance=ti) + ti.set_duration() + + # run on_success_callback before db committing + # otherwise, the LocalTaskJob sees the state is changed to `success`, + # but the task_runner is still running, LocalTaskJob then treats the state is set externally! + _run_finished_callback(callbacks=ti.task.on_success_callback, context=context) + + if not test_mode: + _add_log(event=ti.state, task_instance=ti, session=session) + if ti.state == TaskInstanceState.SUCCESS: + ti._register_dataset_changes(events=context["outlet_events"], session=session) + + TaskInstance.save_to_db(ti=ti, session=session) + if ti.state == TaskInstanceState.SUCCESS: + get_listener_manager().hook.on_task_instance_success( + previous_state=TaskInstanceState.RUNNING, task_instance=ti, session=session + ) + + return None + + @contextlib.contextmanager def set_current_context(context: Context) -> Generator[Context, None, None]: """ @@ -374,6 +559,108 @@ def clear_task_instances( session.flush() +@internal_api_call +@provide_session +def _xcom_pull( + *, + ti, + task_ids: str | Iterable[str] | None = None, + dag_id: str | None = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + map_indexes: int | Iterable[int] | None = None, + default: Any = None, +) -> Any: + """Pull XComs that optionally meet certain criteria. + + :param key: A key for the XCom. If provided, only XComs with matching + keys will be returned. The default key is ``'return_value'``, also + available as constant ``XCOM_RETURN_KEY``. This key is automatically + given to XComs returned by tasks (as opposed to being pushed + manually). To remove the filter, pass *None*. + :param task_ids: Only XComs from tasks with matching ids will be + pulled. Pass *None* to remove the filter. + :param dag_id: If provided, only pulls XComs from this DAG. If *None* + (default), the DAG of the calling task is used. + :param map_indexes: If provided, only pull XComs with matching indexes. + If *None* (default), this is inferred from the task(s) being pulled + (see below for details). + :param include_prior_dates: If False, only XComs from the current + execution_date are returned. If *True*, XComs from previous dates + are returned as well. + + When pulling one single task (``task_id`` is *None* or a str) without + specifying ``map_indexes``, the return value is inferred from whether + the specified task is mapped. If not, value from the one single task + instance is returned. If the task to pull is mapped, an iterator (not a + list) yielding XComs from mapped task instances is returned. In either + case, ``default`` (*None* if not specified) is returned if no matching + XComs are found. + + When pulling multiple tasks (i.e. either ``task_id`` or ``map_index`` is + a non-str iterable), a list of matching XComs is returned. Elements in + the list is ordered by item ordering in ``task_id`` and ``map_index``. + """ + if dag_id is None: + dag_id = ti.dag_id + + query = XCom.get_many( + key=key, + run_id=ti.run_id, + dag_ids=dag_id, + task_ids=task_ids, + map_indexes=map_indexes, + include_prior_dates=include_prior_dates, + session=session, + ) + + # NOTE: Since we're only fetching the value field and not the whole + # class, the @recreate annotation does not kick in. Therefore we need to + # call XCom.deserialize_value() manually. + + # We are only pulling one single task. + if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): + first = query.with_entities( + XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value + ).first() + if first is None: # No matching XCom at all. + return default + if map_indexes is not None or first.map_index < 0: + return XCom.deserialize_value(first) + return LazyXComSelectSequence.from_select( + query.with_entities(XCom.value).order_by(None).statement, + order_by=[XCom.map_index], + session=session, + ) + + # At this point either task_ids or map_indexes is explicitly multi-value. + # Order return values to match task_ids and map_indexes ordering. + ordering = [] + if task_ids is None or isinstance(task_ids, str): + ordering.append(XCom.task_id) + elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}: + ordering.append(case(task_id_whens, value=XCom.task_id)) + else: + ordering.append(XCom.task_id) + if map_indexes is None or isinstance(map_indexes, int): + ordering.append(XCom.map_index) + elif isinstance(map_indexes, range): + order = XCom.map_index + if map_indexes.step < 0: + order = order.desc() + ordering.append(order) + elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}: + ordering.append(case(map_index_whens, value=XCom.map_index)) + else: + ordering.append(XCom.map_index) + return LazyXComSelectSequence.from_select( + query.with_entities(XCom.value).order_by(None).statement, + order_by=ordering, + session=session, + ) + + def _is_mappable_value(value: Any) -> TypeGuard[Collection]: """Whether a value can be used for task mapping. @@ -504,6 +791,34 @@ def _execute_callable(context: Context, **execute_callable_kwargs): return result +def _set_ti_attrs(target, source): + # Fields ordered per model definition + target.start_date = source.start_date + target.end_date = source.end_date + target.duration = source.duration + target.state = source.state + target.try_number = source.try_number + target.max_tries = source.max_tries + target.hostname = source.hostname + target.unixname = source.unixname + target.job_id = source.job_id + target.pool = source.pool + target.pool_slots = source.pool_slots or 1 + target.queue = source.queue + target.priority_weight = source.priority_weight + target.operator = source.operator + target.custom_operator_name = source.custom_operator_name + target.queued_dttm = source.queued_dttm + target.queued_by_job_id = source.queued_by_job_id + target.pid = source.pid + target.executor = source.executor + target.executor_config = source.executor_config + target.external_executor_id = source.external_executor_id + target.trigger_id = source.trigger_id + target.next_method = source.next_method + target.next_kwargs = source.next_kwargs + + def _refresh_from_db( *, task_instance: TaskInstance | TaskInstancePydantic, @@ -534,31 +849,7 @@ def _refresh_from_db( ) if ti: - # Fields ordered per model definition - task_instance.start_date = ti.start_date - task_instance.end_date = ti.end_date - task_instance.duration = ti.duration - task_instance.state = ti.state - task_instance.try_number = ti.try_number - task_instance.max_tries = ti.max_tries - task_instance.hostname = ti.hostname - task_instance.unixname = ti.unixname - task_instance.job_id = ti.job_id - task_instance.pool = ti.pool - task_instance.pool_slots = ti.pool_slots or 1 - task_instance.queue = ti.queue - task_instance.priority_weight = ti.priority_weight - task_instance.operator = ti.operator - task_instance.custom_operator_name = ti.custom_operator_name - task_instance.queued_dttm = ti.queued_dttm - task_instance.queued_by_job_id = ti.queued_by_job_id - task_instance.pid = ti.pid - task_instance.executor = ti.executor - task_instance.executor_config = ti.executor_config - task_instance.external_executor_id = ti.external_executor_id - task_instance.trigger_id = ti.trigger_id - task_instance.next_method = ti.next_method - task_instance.next_kwargs = ti.next_kwargs + _set_ti_attrs(task_instance, ti) else: task_instance.state = None @@ -872,6 +1163,8 @@ def _is_eligible_to_retry(*, task_instance: TaskInstance | TaskInstancePydantic) return task_instance.task.retries and task_instance.try_number <= task_instance.max_tries +@provide_session +@internal_api_call def _handle_failure( *, task_instance: TaskInstance | TaskInstancePydantic, @@ -896,9 +1189,9 @@ def _handle_failure( """ if test_mode is None: test_mode = task_instance.test_mode - + task_instance = _coalesce_to_orm_ti(ti=task_instance, session=session) failure_context = TaskInstance.fetch_handle_failure_context( - ti=task_instance, + ti=task_instance, # type: ignore[arg-type] error=error, test_mode=test_mode, context=context, @@ -1265,6 +1558,132 @@ def _update_rtif(ti, rendered_fields, session: Session | None = None): RenderedTaskInstanceFields.delete_old_records(ti.task_id, ti.dag_id, session=session) +def _coalesce_to_orm_ti(*, ti: TaskInstancePydantic | TaskInstance, session: Session): + from airflow.models.dagrun import DagRun + from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic + + if isinstance(ti, TaskInstancePydantic): + orm_ti = DagRun.fetch_task_instance( + dag_id=ti.dag_id, + dag_run_id=ti.run_id, + task_id=ti.task_id, + map_index=ti.map_index, + session=session, + ) + if TYPE_CHECKING: + assert orm_ti + ti, pydantic_ti = orm_ti, ti + _set_ti_attrs(ti, pydantic_ti) + ti.task = pydantic_ti.task + return ti + + +@internal_api_call +@provide_session +def _defer_task( + ti: TaskInstance | TaskInstancePydantic, exception: TaskDeferred, session: Session = NEW_SESSION +) -> TaskInstancePydantic | TaskInstance: + from airflow.models.trigger import Trigger + + # First, make the trigger entry + trigger_row = Trigger.from_object(exception.trigger) + session.add(trigger_row) + session.flush() + + ti = _coalesce_to_orm_ti(ti=ti, session=session) # ensure orm obj in case it's pydantic + + if TYPE_CHECKING: + assert ti.task + + # Then, update ourselves so it matches the deferral request + # Keep an eye on the logic in `check_and_change_state_before_execution()` + # depending on self.next_method semantics + ti.state = TaskInstanceState.DEFERRED + ti.trigger_id = trigger_row.id + ti.next_method = exception.method_name + ti.next_kwargs = exception.kwargs or {} + + # Calculate timeout too if it was passed + if exception.timeout is not None: + ti.trigger_timeout = timezone.utcnow() + exception.timeout + else: + ti.trigger_timeout = None + + # If an execution_timeout is set, set the timeout to the minimum of + # it and the trigger timeout + execution_timeout = ti.task.execution_timeout + if execution_timeout: + if TYPE_CHECKING: + assert ti.start_date + if ti.trigger_timeout: + ti.trigger_timeout = min(ti.start_date + execution_timeout, ti.trigger_timeout) + else: + ti.trigger_timeout = ti.start_date + execution_timeout + if ti.test_mode: + _add_log(event=ti.state, task_instance=ti, session=session) + session.merge(ti) + session.commit() + return ti + + +@internal_api_call +@provide_session +def _handle_reschedule( + ti, + actual_start_date: datetime, + reschedule_exception: AirflowRescheduleException, + test_mode: bool = False, + session: Session = NEW_SESSION, +): + # Don't record reschedule request in test mode + if test_mode: + return + + ti = _coalesce_to_orm_ti(ti=ti, session=session) + + from airflow.models.dagrun import DagRun # Avoid circular import + + ti.refresh_from_db(session) + + if TYPE_CHECKING: + assert ti.task + + ti.end_date = timezone.utcnow() + ti.set_duration() + + # Lock DAG run to be sure not to get into a deadlock situation when trying to insert + # TaskReschedule which apparently also creates lock on corresponding DagRun entity + with_row_locks( + session.query(DagRun).filter_by( + dag_id=ti.dag_id, + run_id=ti.run_id, + ), + session=session, + ).one() + # Log reschedule request + session.add( + TaskReschedule( + ti.task_id, + ti.dag_id, + ti.run_id, + ti.try_number, + actual_start_date, + ti.end_date, + reschedule_exception.reschedule_date, + ti.map_index, + ) + ) + + # set state + ti.state = TaskInstanceState.UP_FOR_RESCHEDULE + + ti.clear_next_method_args() + + session.merge(ti) + session.commit() + return ti + + class TaskInstance(Base, LoggingMixin): """ Task instances store the state of a task instance. @@ -2452,137 +2871,15 @@ def _run_raw_task( if TYPE_CHECKING: assert self.task - self.test_mode = test_mode - self.refresh_from_task(self.task, pool_override=pool) - self.refresh_from_db(session=session) - - self.job_id = job_id - self.hostname = get_hostname() - self.pid = os.getpid() - if not test_mode: - session.merge(self) - session.commit() - actual_start_date = timezone.utcnow() - Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.start", tags=self.stats_tags) - # Initialize final state counters at zero - for state in State.task_states: - Stats.incr( - f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", - count=0, - tags=self.stats_tags, - ) - # Same metric with tagging - Stats.incr( - "ti.finish", - count=0, - tags={**self.stats_tags, "state": str(state)}, - ) - with set_current_task_instance_session(session=session): - self.task = self.task.prepare_for_execution() - context = self.get_template_context(ignore_param_exceptions=False) - - try: - if not mark_success: - self._execute_task_with_callbacks(context, test_mode, session=session) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - self.state = TaskInstanceState.SUCCESS - except TaskDeferred as defer: - # The task has signalled it wants to defer execution based on - # a trigger. - if raise_on_defer: - raise - self.defer_task(defer=defer, session=session) - self.log.info( - "Pausing task as DEFERRED. dag_id=%s, task_id=%s, run_id=%s, execution_date=%s, start_date=%s", - self.dag_id, - self.task_id, - self.run_id, - _date_or_empty(task_instance=self, attr="execution_date"), - _date_or_empty(task_instance=self, attr="start_date"), - ) - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self) - session.commit() - return TaskReturnCode.DEFERRED - except AirflowSkipException as e: - # Recording SKIP - # log only if exception has any arguments to prevent log flooding - if e.args: - self.log.info(e) - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - _run_finished_callback(callbacks=self.task.on_skipped_callback, context=context) - session.commit() - self.state = TaskInstanceState.SKIPPED - except AirflowRescheduleException as reschedule_exception: - self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, session=session) - session.commit() - return None - except (AirflowFailException, AirflowSensorTimeout) as e: - # If AirflowFailException is raised, task should not retry. - # If a sensor in reschedule mode reaches timeout, task should not retry. - self.handle_failure(e, test_mode, context, force_fail=True, session=session) - session.commit() - raise - except (AirflowTaskTimeout, AirflowException, AirflowTaskTerminated) as e: - if not test_mode: - self.refresh_from_db(lock_for_update=True, session=session) - # for case when task is marked as success/failed externally - # or dagrun timed out and task is marked as skipped - # current behavior doesn't hit the callbacks - if self.state in State.finished: - self.clear_next_method_args() - session.merge(self) - session.commit() - return None - else: - self.handle_failure(e, test_mode, context, session=session) - session.commit() - raise - except SystemExit as e: - # We have already handled SystemExit with success codes (0 and None) in the `_execute_task`. - # Therefore, here we must handle only error codes. - msg = f"Task failed due to SystemExit({e.code})" - self.handle_failure(msg, test_mode, context, session=session) - session.commit() - raise AirflowException(msg) - except BaseException as e: - self.handle_failure(e, test_mode, context, session=session) - session.commit() - raise - finally: - Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}", tags=self.stats_tags) - # Same metric with tagging - Stats.incr("ti.finish", tags={**self.stats_tags, "state": str(self.state)}) - - # Recording SKIPPED or SUCCESS - self.clear_next_method_args() - self.end_date = timezone.utcnow() - _log_state(task_instance=self) - self.set_duration() - - # run on_success_callback before db committing - # otherwise, the LocalTaskJob sees the state is changed to `success`, - # but the task_runner is still running, LocalTaskJob then treats the state is set externally! - _run_finished_callback(callbacks=self.task.on_success_callback, context=context) - - if not test_mode: - session.add(Log(self.state, self)) - session.merge(self).task = self.task - if self.state == TaskInstanceState.SUCCESS: - self._register_dataset_changes(events=context["outlet_events"], session=session) - - session.commit() - if self.state == TaskInstanceState.SUCCESS: - get_listener_manager().hook.on_task_instance_success( - previous_state=TaskInstanceState.RUNNING, task_instance=self, session=session - ) - - return None + return _run_raw_task( + ti=self, + mark_success=mark_success, + test_mode=test_mode, + job_id=job_id, + pool=pool, + raise_on_defer=raise_on_defer, + session=session, + ) def _register_dataset_changes(self, *, events: OutletEventAccessors, session: Session) -> None: if TYPE_CHECKING: @@ -2709,43 +3006,12 @@ def _execute_task(self, context: Context, task_orig: Operator): return _execute_task(self, context, task_orig) @provide_session - def defer_task(self, session: Session, defer: TaskDeferred) -> None: + def defer_task(self, exception: TaskDeferred, session: Session) -> None: """Mark the task as deferred and sets up the trigger that is needed to resume it. :meta: private """ - from airflow.models.trigger import Trigger - - if TYPE_CHECKING: - assert self.task - - # First, make the trigger entry - trigger_row = Trigger.from_object(defer.trigger) - session.add(trigger_row) - session.flush() - - # Then, update ourselves so it matches the deferral request - # Keep an eye on the logic in `check_and_change_state_before_execution()` - # depending on self.next_method semantics - self.state = TaskInstanceState.DEFERRED - self.trigger_id = trigger_row.id - self.next_method = defer.method_name - self.next_kwargs = defer.kwargs or {} - - # Calculate timeout too if it was passed - if defer.timeout is not None: - self.trigger_timeout = timezone.utcnow() + defer.timeout - else: - self.trigger_timeout = None - - # If an execution_timeout is set, set the timeout to the minimum of - # it and the trigger timeout - execution_timeout = self.task.execution_timeout - if execution_timeout: - if self.trigger_timeout: - self.trigger_timeout = min(self.start_date + execution_timeout, self.trigger_timeout) - else: - self.trigger_timeout = self.start_date + execution_timeout + _defer_task(ti=self, exception=exception, session=session) def _run_execute_callback(self, context: Context, task: BaseOperator) -> None: """Functions that need to be run before a Task is executed.""" @@ -2818,53 +3084,14 @@ def _handle_reschedule( test_mode: bool = False, session: Session = NEW_SESSION, ): - # Don't record reschedule request in test mode - if test_mode: - return - - from airflow.models.dagrun import DagRun # Avoid circular import - - self.refresh_from_db(session) - - if TYPE_CHECKING: - assert self.task - - self.end_date = timezone.utcnow() - self.set_duration() - - # Lock DAG run to be sure not to get into a deadlock situation when trying to insert - # TaskReschedule which apparently also creates lock on corresponding DagRun entity - with_row_locks( - session.query(DagRun).filter_by( - dag_id=self.dag_id, - run_id=self.run_id, - ), + _handle_reschedule( + ti=self, + actual_start_date=actual_start_date, + reschedule_exception=reschedule_exception, + test_mode=test_mode, session=session, - ).one() - - # Log reschedule request - session.add( - TaskReschedule( - self.task_id, - self.dag_id, - self.run_id, - self.try_number, - actual_start_date, - self.end_date, - reschedule_exception.reschedule_date, - self.map_index, - ) ) - # set state - self.state = TaskInstanceState.UP_FOR_RESCHEDULE - - self.clear_next_method_args() - - session.merge(self) - session.commit() - self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE") - @staticmethod def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None: """ @@ -2884,16 +3111,15 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) - return tb or error.__traceback__ @classmethod - @internal_api_call - @provide_session def fetch_handle_failure_context( cls, - ti: TaskInstance | TaskInstancePydantic, + ti: TaskInstance, error: None | str | BaseException, test_mode: bool | None = None, context: Context | None = None, force_fail: bool = False, - session: Session = NEW_SESSION, + *, + session: Session, fail_stop: bool = False, ): """ @@ -2990,8 +3216,10 @@ def fetch_handle_failure_context( @internal_api_call @provide_session def save_to_db(ti: TaskInstance | TaskInstancePydantic, session: Session = NEW_SESSION): + ti = _coalesce_to_orm_ti(ti=ti, session=session) session.merge(ti) session.flush() + session.commit() @provide_session def handle_failure( @@ -3265,62 +3493,15 @@ def xcom_pull( a non-str iterable), a list of matching XComs is returned. Elements in the list is ordered by item ordering in ``task_id`` and ``map_index``. """ - if dag_id is None: - dag_id = self.dag_id - - query = XCom.get_many( - key=key, - run_id=self.run_id, - dag_ids=dag_id, + return _xcom_pull( + ti=self, task_ids=task_ids, - map_indexes=map_indexes, + dag_id=dag_id, + key=key, include_prior_dates=include_prior_dates, session=session, - ) - - # NOTE: Since we're only fetching the value field and not the whole - # class, the @recreate annotation does not kick in. Therefore we need to - # call XCom.deserialize_value() manually. - - # We are only pulling one single task. - if (task_ids is None or isinstance(task_ids, str)) and not isinstance(map_indexes, Iterable): - first = query.with_entities( - XCom.run_id, XCom.task_id, XCom.dag_id, XCom.map_index, XCom.value - ).first() - if first is None: # No matching XCom at all. - return default - if map_indexes is not None or first.map_index < 0: - return XCom.deserialize_value(first) - return LazyXComSelectSequence.from_select( - query.with_entities(XCom.value).order_by(None).statement, - order_by=[XCom.map_index], - session=session, - ) - - # At this point either task_ids or map_indexes is explicitly multi-value. - # Order return values to match task_ids and map_indexes ordering. - ordering = [] - if task_ids is None or isinstance(task_ids, str): - ordering.append(XCom.task_id) - elif task_id_whens := {tid: i for i, tid in enumerate(task_ids)}: - ordering.append(case(task_id_whens, value=XCom.task_id)) - else: - ordering.append(XCom.task_id) - if map_indexes is None or isinstance(map_indexes, int): - ordering.append(XCom.map_index) - elif isinstance(map_indexes, range): - order = XCom.map_index - if map_indexes.step < 0: - order = order.desc() - ordering.append(order) - elif map_index_whens := {map_index: i for i, map_index in enumerate(map_indexes)}: - ordering.append(case(map_index_whens, value=XCom.map_index)) - else: - ordering.append(XCom.map_index) - return LazyXComSelectSequence.from_select( - query.with_entities(XCom.value).order_by(None).statement, - order_by=ordering, - session=session, + map_indexes=map_indexes, + default=default, ) @provide_session diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py index fe1ebadc2e59f..d6f55403910b6 100644 --- a/airflow/models/xcom.py +++ b/airflow/models/xcom.py @@ -180,6 +180,7 @@ def set( """ @classmethod + @internal_api_call @provide_session def set( cls, diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py index 2e01bf415a141..afb7d4e2dd064 100644 --- a/airflow/serialization/pydantic/taskinstance.py +++ b/airflow/serialization/pydantic/taskinstance.py @@ -21,9 +21,17 @@ from typing_extensions import Annotated +from airflow.exceptions import AirflowRescheduleException, TaskDeferred from airflow.models import Operator from airflow.models.baseoperator import BaseOperator -from airflow.models.taskinstance import TaskInstance +from airflow.models.taskinstance import ( + TaskInstance, + TaskReturnCode, + _defer_task, + _handle_reschedule, + _run_raw_task, + _set_ti_attrs, +) from airflow.serialization.pydantic.dag import DagModelPydantic from airflow.serialization.pydantic.dag_run import DagRunPydantic from airflow.utils.log.logging_mixin import LoggingMixin @@ -126,6 +134,25 @@ def clear_xcom_data(self, session: Session | None = None): def set_state(self, state, session: Session | None = None) -> bool: return TaskInstance._set_state(ti=self, state=state, session=session) + def _run_raw_task( + self, + mark_success: bool = False, + test_mode: bool = False, + job_id: str | None = None, + pool: str | None = None, + raise_on_defer: bool = False, + session: Session | None = None, + ) -> TaskReturnCode | None: + return _run_raw_task( + ti=self, + mark_success=mark_success, + test_mode=test_mode, + job_id=job_id, + pool=pool, + raise_on_defer=raise_on_defer, + session=session, + ) + def _run_execute_callback(self, context, task): TaskInstance._run_execute_callback(self=self, context=context, task=task) # type: ignore[arg-type] @@ -143,6 +170,7 @@ def xcom_pull( dag_id: str | None = None, key: str = XCOM_RETURN_KEY, include_prior_dates: bool = False, + session: Session | None = None, *, map_indexes: int | Iterable[int] | None = None, default: Any = None, @@ -150,17 +178,26 @@ def xcom_pull( """ Pull an XCom value for this task instance. - TODO: make it works for AIP-44 :param task_ids: task id or list of task ids, if None, the task_id of the current task is used :param dag_id: dag id, if None, the dag_id of the current task is used :param key: the key to identify the XCom value :param include_prior_dates: whether to include prior execution dates + :param session: the sqlalchemy session :param map_indexes: map index or list of map indexes, if None, the map_index of the current task is used :param default: the default value to return if the XCom value does not exist :return: Xcom value """ - return None + return TaskInstance.xcom_pull( + self=self, # type: ignore[arg-type] + task_ids=task_ids, + dag_id=dag_id, + key=key, + include_prior_dates=include_prior_dates, + map_indexes=map_indexes, + default=default, + session=session, + ) def xcom_push( self, @@ -172,12 +209,17 @@ def xcom_push( """ Push an XCom value for this task instance. - TODO: make it works for AIP-44 :param key: the key to identify the XCom value :param value: the value of the XCom :param execution_date: the execution date to push the XCom for """ - pass + return TaskInstance.xcom_push( + self=self, # type: ignore[arg-type] + key=key, + value=value, + execution_date=execution_date, + session=session, + ) def get_dagrun(self, session: Session | None = None) -> DagRunPydantic: """ @@ -259,7 +301,7 @@ def is_eligible_to_retry(self): def handle_failure( self, - error: None | str | Exception | KeyboardInterrupt, + error: None | str | BaseException, test_mode: bool | None = None, context: Context | None = None, force_fail: bool = False, @@ -451,6 +493,30 @@ def command_as_list( cfg_path=cfg_path, ) + def _register_dataset_changes(self, *, events, session: Session | None = None) -> None: + TaskInstance._register_dataset_changes(self=self, events=events, session=session) # type: ignore[arg-type] + + def defer_task(self, exception: TaskDeferred, session: Session | None = None): + """Defer task.""" + updated_ti = _defer_task(ti=self, exception=exception, session=session) + _set_ti_attrs(self, updated_ti) + + def _handle_reschedule( + self, + actual_start_date: datetime, + reschedule_exception: AirflowRescheduleException, + test_mode: bool = False, + session: Session | None = None, + ): + updated_ti = _handle_reschedule( + ti=self, + actual_start_date=actual_start_date, + reschedule_exception=reschedule_exception, + test_mode=test_mode, + session=session, + ) + _set_ti_attrs(self, updated_ti) # _handle_reschedule is a remote call that mutates the TI + if is_pydantic_2_installed(): TaskInstancePydantic.model_rebuild() diff --git a/airflow/utils/task_instance_session.py b/airflow/utils/task_instance_session.py index d303ef7cccc62..6234463f0d879 100644 --- a/airflow/utils/task_instance_session.py +++ b/airflow/utils/task_instance_session.py @@ -52,6 +52,9 @@ def get_current_task_instance_session() -> Session: @contextlib.contextmanager def set_current_task_instance_session(session: Session): + if InternalApiConfig.get_use_internal_api(): + yield + return global __current_task_instance_session if __current_task_instance_session: raise RuntimeError(