diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index d60d182582e1d..015cc231f7dd4 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -462,6 +462,10 @@ def wait(self) -> int: # If it hasn't, assume it's failed self._exit_code = self._exit_code if self._exit_code is not None else 1 + # If the process has finished in a terminal state, update the state of the TaskInstance + # to reflect the final state of the process. + # For states like `deferred`, the process will exit with 0, but the state will be updated + # by the subprocess in the `handle_requests` method. if self.final_state in TerminalTIState: self.client.task_instances.finish( id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc) @@ -590,26 +594,16 @@ def final_state(self): """ The final state of the TaskInstance. - By default this will be derived from the exit code of the task + By default, this will be derived from the exit code of the task (0=success, failed otherwise) but can be changed by the subprocess sending a TaskState message, as long as the process exits with 0 Not valid before the process has finished. """ - if self._final_state: - return self._final_state if self._exit_code == 0: return self._terminal_state or TerminalTIState.SUCCESS return TerminalTIState.FAILED - @final_state.setter - def final_state(self, value): - """Setter for final_state for certain task instance stated present in IntermediateTIState.""" - # TODO: Remove the setter and manage using the final_state property - # to be taken in a follow up - if value not in TerminalTIState: - self._final_state = value - def __rich_repr__(self): yield "ti_id", self.ti_id yield "pid", self.pid @@ -639,10 +633,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N # if isinstance(msg, TaskState): # self._terminal_state = msg.state - # elif isinstance(msg, ReadXCom): - # resp = XComResponse(key="secret", value=True) - # encoder.encode_into(resp, buffer) - # self.stdin.write(buffer + b"\n") if isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) resp = conn.model_dump_json(exclude_unset=True).encode() @@ -653,7 +643,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) resp = xcom.model_dump_json(exclude_unset=True).encode() elif isinstance(msg, DeferTask): - self.final_state = IntermediateTIState.DEFERRED + self._terminal_state = IntermediateTIState.DEFERRED self.client.task_instances.defer(self.ti_id, msg) resp = None else: @@ -694,11 +684,7 @@ def cb(sock: socket): # We could have read multiple lines in one go, yield them all while (newline_pos := buffer.find(b"\n")) != -1: - if TYPE_CHECKING: - # We send in a memoryvuew, but pretend it's a bytes, as Buffer is only in 3.12+ - line = buffer[: newline_pos + 1] - else: - line = memoryview(buffer)[: newline_pos + 1] # Include the newline character + line = buffer[: newline_pos + 1] gen.send(line) buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data @@ -759,14 +745,22 @@ def supervise( server: str | None = None, dry_run: bool = False, log_path: str | None = None, + client: Client | None = None, ) -> int: """ Run a single task execution to completion. - Returns the exit code of the process + :param ti: The task instance to run. 
+ :param dag_path: The file path to the DAG. + :param token: Authentication token for the API client. + :param server: Base URL of the API server. + :param dry_run: If True, execute without actual task execution (simulate run). + :param log_path: Path to write logs, if required. + :param client: Optional preconfigured client for communication with the server (Mostly for tests). + :return: Exit code of the process. """ # One or the other - if (not server) ^ dry_run: + if not client and ((not server) ^ dry_run): raise ValueError(f"Can only specify one of {server=} or {dry_run=}") if not dag_path: @@ -777,8 +771,9 @@ def supervise( dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1) - limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) - client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) + if not client: + limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) + client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) start = time.monotonic() @@ -805,5 +800,5 @@ def supervise( exit_code = process.wait() end = time.monotonic() - log.debug("Task finished", exit_code=exit_code, duration=end - start) + log.info("Task finished", exit_code=exit_code, duration=end - start, final_state=process.final_state) return exit_code diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index c693a77eac205..9c8bc4942294f 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -168,13 +168,11 @@ def run(ti: RuntimeTaskInstance, log: Logger): next_method = defer.method_name timeout = defer.timeout msg = DeferTask( - state="deferred", classpath=classpath, trigger_kwargs=trigger_kwargs, next_method=next_method, trigger_timeout=timeout, ) - global SUPERVISOR_COMMS SUPERVISOR_COMMS.send_request(msg=msg, log=log) except AirflowSkipException: ... diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index f8333714498a0..cd9abae55c48c 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -259,6 +259,53 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): "timestamp": "2024-11-07T12:34:56.078901Z", } in captured_logs + def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, time_machine, mocker): + """ + Test that the supervisor handles a deferred task correctly. + + This includes ensuring the task starts and executes successfully, and that the task is deferred (via + the API client) with the expected parameters. 
+ """ + + ti = TaskInstance( + id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1 + ) + dagfile_path = test_dags_dir / "super_basic_deferred_run.py" + + # Create a mock client to assert calls to the client + # We assume the implementation of the client is correct and only need to check the calls + mock_client = mocker.Mock(spec=sdk_client.Client) + + instant = tz.datetime(2024, 11, 7, 12, 34, 56, 0) + time_machine.move_to(instant, tick=False) + + # Assert supervisor runs the task successfully + assert supervise(ti=ti, dag_path=dagfile_path, token="", client=mock_client) == 0 + + # Validate calls to the client + mock_client.task_instances.start.assert_called_once_with(ti.id, mocker.ANY, mocker.ANY) + mock_client.task_instances.heartbeat.assert_called_once_with(ti.id, pid=mocker.ANY) + mock_client.task_instances.defer.assert_called_once_with( + ti.id, + DeferTask( + classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger", + trigger_kwargs={"moment": "2024-11-07T12:34:59Z", "end_from_trigger": False}, + next_method="execute_complete", + ), + ) + + # We are asserting the log messages here to ensure the task ran successfully + # and mainly to get the final state of the task matches one in the DB. + assert { + "exit_code": 0, + "duration": 0.0, + "final_state": "deferred", + "event": "Task finished", + "timestamp": mocker.ANY, + "level": "info", + "logger": "supervisor", + } in captured_logs + def test_supervisor_handles_already_running_task(self): """Test that Supervisor prevents starting a Task Instance that is already running.""" ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)