From 33bf3ede9e7d7495b68dbdeef687ca4c53b1a2e2 Mon Sep 17 00:00:00 2001
From: Kaxil Naik <kaxilnaik@gmail.com>
Date: Mon, 2 Dec 2024 20:23:25 +0000
Subject: [PATCH] AIP-72: Improve handling of `deferred` state in Supervisor
 (#44579)

- Added `_terminal_state` in `WatchedSubprocess` to manage task final states like `deferred` directly.
- Updated `wait` method in `WatchedSubprocess` to finalize task states and call the API with the terminal states (success, failure, etc)
- Simplified the `final_state` property by removing unnecessary setter logic.
- Fixed a bug where `make_buffered_socket_reader` was using `memoryview` and `msg = decoder.validate_json(line)` expected bytes.
---
 .../airflow/sdk/execution_time/supervisor.py  | 47 +++++++++----------
 .../airflow/sdk/execution_time/task_runner.py |  2 -
 .../tests/execution_time/test_supervisor.py   | 47 +++++++++++++++++++
 3 files changed, 68 insertions(+), 28 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
index d60d182582e1d..015cc231f7dd4 100644
--- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -462,6 +462,10 @@ def wait(self) -> int:
         # If it hasn't, assume it's failed
         self._exit_code = self._exit_code if self._exit_code is not None else 1
 
+        # If the process has finished in a terminal state, update the state of the TaskInstance
+        # to reflect the final state of the process.
+        # For states like `deferred`, the process will exit with 0, but the state will be updated
+        # by the subprocess in the `handle_requests` method.
         if self.final_state in TerminalTIState:
             self.client.task_instances.finish(
                 id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
@@ -590,26 +594,16 @@ def final_state(self):
         """
         The final state of the TaskInstance.
 
-        By default this will be derived from the exit code of the task
+        By default, this will be derived from the exit code of the task
         (0=success, failed otherwise) but can be changed by the subprocess
         sending a TaskState message, as long as the process exits with 0
 
         Not valid before the process has finished.
         """
-        if self._final_state:
-            return self._final_state
         if self._exit_code == 0:
             return self._terminal_state or TerminalTIState.SUCCESS
         return TerminalTIState.FAILED
 
-    @final_state.setter
-    def final_state(self, value):
-        """Setter for final_state for certain task instance stated present in IntermediateTIState."""
-        # TODO: Remove the setter and manage using the final_state property
-        # to be taken in a follow up
-        if value not in TerminalTIState:
-            self._final_state = value
-
     def __rich_repr__(self):
         yield "ti_id", self.ti_id
         yield "pid", self.pid
@@ -639,10 +633,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
 
             # if isinstance(msg, TaskState):
             #     self._terminal_state = msg.state
-            # elif isinstance(msg, ReadXCom):
-            #     resp = XComResponse(key="secret", value=True)
-            #     encoder.encode_into(resp, buffer)
-            #     self.stdin.write(buffer + b"\n")
             if isinstance(msg, GetConnection):
                 conn = self.client.connections.get(msg.conn_id)
                 resp = conn.model_dump_json(exclude_unset=True).encode()
@@ -653,7 +643,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
                 xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
                 resp = xcom.model_dump_json(exclude_unset=True).encode()
             elif isinstance(msg, DeferTask):
-                self.final_state = IntermediateTIState.DEFERRED
+                self._terminal_state = IntermediateTIState.DEFERRED
                 self.client.task_instances.defer(self.ti_id, msg)
                 resp = None
             else:
@@ -694,11 +684,7 @@ def cb(sock: socket):
 
         # We could have read multiple lines in one go, yield them all
         while (newline_pos := buffer.find(b"\n")) != -1:
-            if TYPE_CHECKING:
-                # We send in a memoryvuew, but pretend it's a bytes, as Buffer is only in 3.12+
-                line = buffer[: newline_pos + 1]
-            else:
-                line = memoryview(buffer)[: newline_pos + 1]  # Include the newline character
+            line = buffer[: newline_pos + 1]
             gen.send(line)
             buffer = buffer[newline_pos + 1 :]  # Update the buffer with remaining data
 
@@ -759,14 +745,22 @@ def supervise(
     server: str | None = None,
     dry_run: bool = False,
     log_path: str | None = None,
+    client: Client | None = None,
 ) -> int:
     """
     Run a single task execution to completion.
 
-    Returns the exit code of the process
+    :param ti: The task instance to run.
+    :param dag_path: The file path to the DAG.
+    :param token: Authentication token for the API client.
+    :param server: Base URL of the API server.
+    :param dry_run: If True, execute without actual task execution (simulate run).
+    :param log_path: Path to write logs, if required.
+    :param client: Optional preconfigured client for communication with the server (Mostly for tests).
+    :return: Exit code of the process.
     """
     # One or the other
-    if (not server) ^ dry_run:
+    if not client and ((not server) ^ dry_run):
         raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
 
     if not dag_path:
@@ -777,8 +771,9 @@ def supervise(
 
         dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1)
 
-    limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
-    client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)
+    if not client:
+        limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
+        client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)
 
     start = time.monotonic()
 
@@ -805,5 +800,5 @@ def supervise(
 
     exit_code = process.wait()
     end = time.monotonic()
-    log.debug("Task finished", exit_code=exit_code, duration=end - start)
+    log.info("Task finished", exit_code=exit_code, duration=end - start, final_state=process.final_state)
     return exit_code
diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
index c693a77eac205..9c8bc4942294f 100644
--- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -168,13 +168,11 @@ def run(ti: RuntimeTaskInstance, log: Logger):
         next_method = defer.method_name
         timeout = defer.timeout
         msg = DeferTask(
-            state="deferred",
             classpath=classpath,
             trigger_kwargs=trigger_kwargs,
             next_method=next_method,
             trigger_timeout=timeout,
         )
-        global SUPERVISOR_COMMS
         SUPERVISOR_COMMS.send_request(msg=msg, log=log)
     except AirflowSkipException:
         ...
diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py
index f8333714498a0..cd9abae55c48c 100644
--- a/task_sdk/tests/execution_time/test_supervisor.py
+++ b/task_sdk/tests/execution_time/test_supervisor.py
@@ -259,6 +259,53 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
             "timestamp": "2024-11-07T12:34:56.078901Z",
         } in captured_logs
 
+    def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, time_machine, mocker):
+        """
+        Test that the supervisor handles a deferred task correctly.
+
+        This includes ensuring the task starts and executes successfully, and that the task is deferred (via
+        the API client) with the expected parameters.
+        """
+
+        ti = TaskInstance(
+            id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1
+        )
+        dagfile_path = test_dags_dir / "super_basic_deferred_run.py"
+
+        # Create a mock client to assert calls to the client
+        # We assume the implementation of the client is correct and only need to check the calls
+        mock_client = mocker.Mock(spec=sdk_client.Client)
+
+        instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
+        time_machine.move_to(instant, tick=False)
+
+        # Assert supervisor runs the task successfully
+        assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0
+
+        # Validate calls to the client
+        mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY)
+        mock_client.task_instances.heartbeat.assert_called_once_with(ti.id, pid=mocker.ANY)
+        mock_client.task_instances.defer.assert_called_once_with(
+            ti.id,
+            DeferTask(
+                classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
+                trigger_kwargs={"moment": "2024-11-07T12:34:59Z", "end_from_trigger": False},
+                next_method="execute_complete",
+            ),
+        )
+
+        # We are asserting the log messages here to ensure the task ran successfully
+        # and mainly to get the final state of the task matches one in the DB.
+        assert {
+            "exit_code": 0,
+            "duration": 0.0,
+            "final_state": "deferred",
+            "event": "Task finished",
+            "timestamp": mocker.ANY,
+            "level": "info",
+            "logger": "supervisor",
+        } in captured_logs
+
     def test_supervisor_handles_already_running_task(self):
         """Test that Supervisor prevents starting a Task Instance that is already running."""
         ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)