From 761cedd06103e10627ad62727eae0aac15387130 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Wed, 27 Nov 2024 20:37:37 +0530 Subject: [PATCH] AIP 72: Handling "deferrable" tasks in execution_api and task SDK (#44241) closes: #44137 Co-authored-by: Kaxil Naik --- .../execution_api/datamodels/taskinstance.py | 28 ++++++++++- .../execution_api/routes/task_instances.py | 25 ++++++++++ task_sdk/src/airflow/sdk/api/client.py | 9 ++++ .../airflow/sdk/api/datamodels/_generated.py | 14 +++++- .../src/airflow/sdk/execution_time/comms.py | 9 +++- .../airflow/sdk/execution_time/supervisor.py | 28 +++++++++-- .../airflow/sdk/execution_time/task_runner.py | 17 +++++-- .../tests/dags/super_basic_deferred_run.py | 37 +++++++++++++++ .../tests/execution_time/test_supervisor.py | 21 +++++++-- .../tests/execution_time/test_task_runner.py | 41 ++++++++++++++++- .../routes/test_task_instances.py | 46 +++++++++++++++++++ 11 files changed, 257 insertions(+), 18 deletions(-) create mode 100644 task_sdk/tests/dags/super_basic_deferred_run.py diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index a2be682cd60d9..ae05cc140c435 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -18,9 +18,10 @@ from __future__ import annotations import uuid -from typing import Annotated, Literal, Union +from datetime import timedelta +from typing import Annotated, Any, Literal, Union -from pydantic import Discriminator, Tag, WithJsonSchema +from pydantic import Discriminator, Field, Tag, WithJsonSchema from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.core_api.base import BaseModel @@ -60,6 +61,26 @@ class TITargetStatePayload(BaseModel): state: IntermediateTIState +class TIDeferredStatePayload(BaseModel): + """Schema for updating TaskInstance to a deferred state.""" + + state: Annotated[ + Literal[IntermediateTIState.DEFERRED], + # Specify a default in the schema, but not in code, so Pydantic marks it as required. + WithJsonSchema( + { + "type": "string", + "enum": [IntermediateTIState.DEFERRED], + "default": IntermediateTIState.DEFERRED, + } + ), + ] + classpath: str + trigger_kwargs: Annotated[dict[str, Any], Field(default_factory=dict)] + next_method: str + trigger_timeout: timedelta | None = None + + def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: """ Determine the discriminator key for TaskInstance state transitions. @@ -77,6 +98,8 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: return str(state) elif state in set(TerminalTIState): return "_terminal_" + elif state == TIState.DEFERRED: + return "deferred" return "_other_" @@ -87,6 +110,7 @@ def ti_state_discriminator(v: dict[str, str] | BaseModel) -> str: Annotated[TIEnterRunningPayload, Tag("running")], Annotated[TITerminalStatePayload, Tag("_terminal_")], Annotated[TITargetStatePayload, Tag("_other_")], + Annotated[TIDeferredStatePayload, Tag("deferred")], ], Discriminator(ti_state_discriminator), ] diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 3adbd51ff2aae..0927e92a1f84b 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -30,12 +30,14 @@ from airflow.api_fastapi.common.db.common import get_session from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( + TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, TIStateUpdate, TITerminalStatePayload, ) from airflow.models.taskinstance import TaskInstance as TI +from airflow.models.trigger import Trigger from airflow.utils import timezone from airflow.utils.state import State @@ -122,6 +124,29 @@ def ti_update_state( ) elif isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) + elif isinstance(ti_patch_payload, TIDeferredStatePayload): + # Calculate timeout if it was passed + timeout = None + if ti_patch_payload.trigger_timeout is not None: + timeout = timezone.utcnow() + ti_patch_payload.trigger_timeout + + trigger_row = Trigger( + classpath=ti_patch_payload.classpath, + kwargs=ti_patch_payload.trigger_kwargs, + ) + session.add(trigger_row) + + # TODO: HANDLE execution timeout later as it requires a call to the DB + # either get it from the serialised DAG or get it from the API + + query = update(TI).where(TI.id == ti_id_str) + query = query.values( + state=State.DEFERRED, + trigger_id=trigger_row.id, + next_method=ti_patch_payload.next_method, + next_kwargs=ti_patch_payload.trigger_kwargs, + trigger_timeout=timeout, + ) # TODO: Replace this with FastAPI's Custom Exception handling: # https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index b1ec860773d40..80965ecdaf80b 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -31,6 +31,7 @@ from airflow.sdk.api.datamodels._generated import ( ConnectionResponse, TerminalTIState, + TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, TITerminalStatePayload, @@ -116,6 +117,7 @@ def start(self, id: uuid.UUID, pid: int, when: datetime): def finish(self, id: uuid.UUID, state: TerminalTIState, when: datetime): """Tell the API server that this TI has reached a terminal state.""" + # TODO: handle the naming better. finish sounds wrong as "even" deferred is essentially finishing. body = TITerminalStatePayload(end_date=when, state=TerminalTIState(state)) self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) @@ -124,6 +126,13 @@ def heartbeat(self, id: uuid.UUID, pid: int): body = TIHeartbeatInfo(pid=pid, hostname=get_hostname()) self.client.put(f"task-instances/{id}/heartbeat", content=body.model_dump_json()) + def defer(self, id: uuid.UUID, msg): + """Tell the API server that this TI has been deferred.""" + body = TIDeferredStatePayload(**msg.model_dump(exclude_unset=True)) + + # Create a deferred state payload from msg + self.client.patch(f"task-instances/{id}/state", content=body.model_dump_json()) + class ConnectionOperations: __slots__ = ("client",) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 0cfd3a12c40a9..1d6d344968dfd 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -21,7 +21,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta from enum import Enum from typing import Annotated, Any, Literal from uuid import UUID @@ -58,6 +58,18 @@ class IntermediateTIState(str, Enum): DEFERRED = "deferred" +class TIDeferredStatePayload(BaseModel): + """ + Schema for updating TaskInstance to a deferred state. + """ + + state: Annotated[Literal["deferred"] | None, Field(title="State")] = "deferred" + classpath: Annotated[str, Field(title="Classpath")] + trigger_kwargs: Annotated[dict[str, Any] | None, Field(title="Trigger Kwargs")] = None + next_method: Annotated[str, Field(title="Next Method")] + trigger_timeout: Annotated[timedelta | None, Field(title="Trigger Timeout")] = None + + class TIEnterRunningPayload(BaseModel): """ Schema for updating TaskInstance to 'RUNNING' state with minimal required fields. diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index f57190980dd06..0e45e45700e88 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -51,6 +51,7 @@ ConnectionResponse, TaskInstance, TerminalTIState, + TIDeferredStatePayload, VariableResponse, XComResponse, ) @@ -103,6 +104,12 @@ class TaskState(BaseModel): type: Literal["TaskState"] = "TaskState" +class DeferTask(TIDeferredStatePayload): + """Update a task instance state to deferred.""" + + type: Literal["DeferTask"] = "DeferTask" + + class GetXCom(BaseModel): key: str dag_id: str @@ -123,6 +130,6 @@ class GetVariable(BaseModel): ToSupervisor = Annotated[ - Union[TaskState, GetXCom, GetConnection, GetVariable], + Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask], Field(discriminator="type"), ] diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 72fe0b65a4b68..6dfbe415058d9 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -43,8 +43,9 @@ from pydantic import TypeAdapter from airflow.sdk.api.client import Client -from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState +from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( + DeferTask, GetConnection, GetVariable, GetXCom, @@ -263,6 +264,7 @@ class WatchedSubprocess: _process: psutil.Process _exit_code: int | None = None _terminal_state: str | None = None + _final_state: str | None = None _last_heartbeat: float = 0 @@ -398,9 +400,10 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 - self.client.task_instances.finish( - id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) - ) + if self.final_state in TerminalTIState: + self.client.task_instances.finish( + id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) + ) return self._exit_code def _monitor_subprocess(self): @@ -472,10 +475,20 @@ def final_state(self): Not valid before the process has finished. """ + if self._final_state: + return self._final_state if self._exit_code == 0: return self._terminal_state or TerminalTIState.SUCCESS return TerminalTIState.FAILED + @final_state.setter + def final_state(self, value): + """Setter for final_state for certain task instance stated present in IntermediateTIState.""" + # TODO: Remove the setter and manage using the final_state property + # to be taken in a follow up + if value not in TerminalTIState: + self._final_state = value + def __rich_repr__(self): yield "ti_id", self.ti_id yield "pid", self.pid @@ -518,11 +531,16 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) resp = xcom.model_dump_json(exclude_unset=True).encode() + elif isinstance(msg, DeferTask): + self.final_state = IntermediateTIState.DEFERRED + self.client.task_instances.defer(self.ti_id, msg) + resp = None else: log.error("Unhandled request", msg=msg) continue - self.stdin.write(resp + b"\n") + if resp: + self.stdin.write(resp + b"\n") # Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index edf37f6ca897d..9abf3a4796bb6 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -30,7 +30,7 @@ from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.definitions.baseoperator import BaseOperator -from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor, ToTask +from airflow.sdk.execution_time.comms import DeferTask, StartupDetails, ToSupervisor, ToTask if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -159,8 +159,19 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO next_method to support resuming from deferred # TODO: Get a real context object ti.task.execute({"task_instance": ti}) # type: ignore[attr-defined] - except TaskDeferred: - ... + except TaskDeferred as defer: + classpath, trigger_kwargs = defer.trigger.serialize() + next_method = defer.method_name + timeout = defer.timeout + msg = DeferTask( + state="deferred", + classpath=classpath, + trigger_kwargs=trigger_kwargs, + next_method=next_method, + trigger_timeout=timeout, + ) + global SUPERVISOR_COMMS + SUPERVISOR_COMMS.send_request(msg=msg, log=log) except AirflowSkipException: ... except AirflowRescheduleException: diff --git a/task_sdk/tests/dags/super_basic_deferred_run.py b/task_sdk/tests/dags/super_basic_deferred_run.py new file mode 100644 index 0000000000000..453d9e5f6c742 --- /dev/null +++ b/task_sdk/tests/dags/super_basic_deferred_run.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import datetime + +from airflow.providers.standard.sensors.date_time import DateTimeSensorAsync +from airflow.sdk.definitions.dag import dag +from airflow.utils import timezone + + +@dag() +def super_basic_deferred_run(): + DateTimeSensorAsync( + task_id="async", + target_time=str(timezone.utcnow() + datetime.timedelta(seconds=3)), + poke_interval=60, + timeout=600, + ) + + +super_basic_deferred_run() diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 9f582074586a5..2cb8c24af1b13 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -39,6 +39,7 @@ from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity from airflow.sdk.execution_time.comms import ( ConnectionResult, + DeferTask, GetConnection, GetVariable, GetXCom, @@ -53,6 +54,8 @@ if TYPE_CHECKING: import kgb +TI_ID = uuid7() + def lineno(): """Returns the current line number in our program.""" @@ -307,7 +310,7 @@ class TestHandleRequest: def watched_subprocess(self, mocker): """Fixture to provide a WatchedSubprocess instance.""" return WatchedSubprocess( - ti_id=uuid7(), + ti_id=TI_ID, pid=12345, stdin=BytesIO(), client=mocker.Mock(), @@ -319,7 +322,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql"}', + b'{"conn_id":"test_conn","conn_type":"mysql"}\n', "connections.get", ("test_conn",), ConnectionResult(conn_id="test_conn", conn_type="mysql"), @@ -327,7 +330,7 @@ def watched_subprocess(self, mocker): ), pytest.param( GetVariable(key="test_key"), - b'{"key":"test_key","value":"test_value"}', + b'{"key":"test_key","value":"test_value"}\n', "variables.get", ("test_key",), VariableResult(key="test_key", value="test_value"), @@ -335,12 +338,20 @@ def watched_subprocess(self, mocker): ), pytest.param( GetXCom(dag_id="test_dag", run_id="test_run", task_id="test_task", key="test_key"), - b'{"key":"test_key","value":"test_value"}', + b'{"key":"test_key","value":"test_value"}\n', "xcoms.get", ("test_dag", "test_run", "test_task", "test_key", -1), XComResult(key="test_key", value="test_value"), id="get_xcom", ), + pytest.param( + DeferTask(next_method="execute_callback", classpath="my-classpath"), + b"", + "task_instances.defer", + (TI_ID, DeferTask(next_method="execute_callback", classpath="my-classpath")), + "", + id="patch_task_instance_to_deferred", + ), ], ) def test_handle_requests( @@ -379,4 +390,4 @@ def test_handle_requests( mock_client_method.assert_called_once_with(*method_arg) # Verify the response was added to the buffer - assert watched_subprocess.stdin.getvalue() == expected_buffer + b"\n" + assert watched_subprocess.stdin.getvalue() == expected_buffer diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 4666f2049e806..a66f54c709fc3 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -18,6 +18,7 @@ from __future__ import annotations import uuid +from datetime import timedelta from pathlib import Path from socket import socketpair from unittest import mock @@ -27,8 +28,9 @@ from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.datamodels._generated import TaskInstance -from airflow.sdk.execution_time.comms import StartupDetails +from airflow.sdk.execution_time.comms import DeferTask, StartupDetails from airflow.sdk.execution_time.task_runner import CommsDecoder, parse, run +from airflow.utils import timezone class TestCommsDecoder: @@ -86,3 +88,40 @@ def test_run_basic(test_dags_dir: Path): ti = parse(what) run(ti, log=mock.MagicMock()) + + +def test_run_deferred_basic(test_dags_dir: Path, time_machine): + """Test that a task can transition to a deferred state.""" + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="c", try_number=1 + ), + file=str(test_dags_dir / "super_basic_deferred_run.py"), + requests_fd=0, + ) + + # Use the time machine to set the current time + instant = timezone.datetime(2024, 11, 22) + time_machine.move_to(instant, tick=False) + + # Expected DeferTask + expected_defer_task = DeferTask( + state="deferred", + classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", + trigger_kwargs={ + "end_from_trigger": False, + "moment": instant + timedelta(seconds=3), + }, + next_method="execute_complete", + trigger_timeout=None, + ) + + # Run the task + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + ti = parse(what) + run(ti, log=mock.MagicMock()) + + # send_request will only be called when the TaskDeferred exception is raised + mock_supervisor_comms.send_request.assert_called_once_with(msg=expected_defer_task, log=mock.ANY) diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index d2285e7e3a92f..32e6b5023e81f 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -23,6 +23,7 @@ from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError +from airflow.models import Trigger from airflow.models.taskinstance import TaskInstance from airflow.utils import timezone from airflow.utils.state import State, TaskInstanceState @@ -196,6 +197,51 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta assert response.status_code == 500 assert response.json()["detail"] == "Database error occurred" + def test_ti_update_state_to_deferred(self, client, session, create_task_instance, time_machine): + """ + Test that tests if the transition to deferred state is handled correctly. + """ + + ti = create_task_instance( + task_id="test_ti_update_state_to_deferred", + state=State.RUNNING, + session=session, + ) + session.commit() + + instant = timezone.datetime(2024, 11, 22) + time_machine.move_to(instant, tick=False) + + payload = { + "state": "deferred", + "trigger_kwargs": {"key": "value"}, + "classpath": "my-classpath", + "next_method": "execute_callback", + "trigger_timeout": "P1D", # 1 day + } + + response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + + assert response.status_code == 204 + assert response.text == "" + + session.expire_all() + + tis = session.query(TaskInstance).all() + assert len(tis) == 1 + + assert tis[0].state == TaskInstanceState.DEFERRED + assert tis[0].next_method == "execute_callback" + assert tis[0].next_kwargs == {"key": "value"} + # TODO: Make TI.trigger_timeout a UtcDateTime instead of DateTime + assert tis[0].trigger_timeout == timezone.datetime(2024, 11, 23).replace(tzinfo=None) + + t = session.query(Trigger).all() + assert len(t) == 1 + assert t[0].created_date == instant + assert t[0].classpath == "my-classpath" + assert t[0].kwargs == {"key": "value"} + class TestTIHealthEndpoint: def setup_method(self):