Skip to content

Commit

Permalink
AIP-72: Improve handling of deferred state in Supervisor (apache#44579)
Browse files Browse the repository at this point in the history

- Added `_terminal_state` in `WatchedSubprocess` to manage task final states like `deferred` directly.
- Updated the `wait` method in `WatchedSubprocess` to finalize task states and call the API with the terminal state (success, failure, etc.).
- Simplified the `final_state` property by removing unnecessary setter logic.
- Fixed a bug where `make_buffered_socket_reader` sent each line to the generator as a `memoryview`, while `msg = decoder.validate_json(line)` expects `bytes`.
  • Loading branch information
kaxil authored Dec 2, 2024
1 parent 5bddd13 commit 33bf3ed
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 28 deletions.
47 changes: 21 additions & 26 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,10 @@ def wait(self) -> int:
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1

# If the process has finished in a terminal state, update the state of the TaskInstance
# to reflect the final state of the process.
# For states like `deferred`, the process will exit with 0, but the state will already have been
# updated via the API in the `handle_requests` method, when the subprocess sent its `DeferTask` request.
if self.final_state in TerminalTIState:
self.client.task_instances.finish(
id=self.ti_id, state=self.final_state, when=datetime.now(tz=timezone.utc)
Expand Down Expand Up @@ -590,26 +594,16 @@ def final_state(self):
"""
The final state of the TaskInstance.
By default this will be derived from the exit code of the task
By default, this will be derived from the exit code of the task
(0=success, failed otherwise) but can be changed by the subprocess
sending a TaskState message, as long as the process exits with 0
Not valid before the process has finished.
"""
if self._final_state:
return self._final_state
if self._exit_code == 0:
return self._terminal_state or TerminalTIState.SUCCESS
return TerminalTIState.FAILED

@final_state.setter
def final_state(self, value):
"""Setter for final_state for certain task instance stated present in IntermediateTIState."""
# TODO: Remove the setter and manage using the final_state property
# to be taken in a follow up
if value not in TerminalTIState:
self._final_state = value

def __rich_repr__(self):
yield "ti_id", self.ti_id
yield "pid", self.pid
Expand Down Expand Up @@ -639,10 +633,6 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N

# if isinstance(msg, TaskState):
# self._terminal_state = msg.state
# elif isinstance(msg, ReadXCom):
# resp = XComResponse(key="secret", value=True)
# encoder.encode_into(resp, buffer)
# self.stdin.write(buffer + b"\n")
if isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
resp = conn.model_dump_json(exclude_unset=True).encode()
Expand All @@ -653,7 +643,7 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
resp = xcom.model_dump_json(exclude_unset=True).encode()
elif isinstance(msg, DeferTask):
self.final_state = IntermediateTIState.DEFERRED
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
else:
Expand Down Expand Up @@ -694,11 +684,7 @@ def cb(sock: socket):

# We could have read multiple lines in one go, yield them all
while (newline_pos := buffer.find(b"\n")) != -1:
if TYPE_CHECKING:
# We send in a memoryview, but pretend it's a bytes, as Buffer is only in 3.12+
line = buffer[: newline_pos + 1]
else:
line = memoryview(buffer)[: newline_pos + 1] # Include the newline character
line = buffer[: newline_pos + 1]
gen.send(line)
buffer = buffer[newline_pos + 1 :] # Update the buffer with remaining data

Expand Down Expand Up @@ -759,14 +745,22 @@ def supervise(
server: str | None = None,
dry_run: bool = False,
log_path: str | None = None,
client: Client | None = None,
) -> int:
"""
Run a single task execution to completion.
Returns the exit code of the process
:param ti: The task instance to run.
:param dag_path: The file path to the DAG.
:param token: Authentication token for the API client.
:param server: Base URL of the API server.
:param dry_run: If True, execute without actual task execution (simulate run).
:param log_path: Path to write logs, if required.
:param client: Optional preconfigured client for communication with the server (Mostly for tests).
:return: Exit code of the process.
"""
# One or the other
if (not server) ^ dry_run:
if not client and ((not server) ^ dry_run):
raise ValueError(f"Can only specify one of {server=} or {dry_run=}")

if not dag_path:
Expand All @@ -777,8 +771,9 @@ def supervise(

dag_path = str_path.replace("DAGS_FOLDER/", DAGS_FOLDER + "/", 1)

limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)
if not client:
limits = httpx.Limits(max_keepalive_connections=1, max_connections=10)
client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token)

start = time.monotonic()

Expand All @@ -805,5 +800,5 @@ def supervise(

exit_code = process.wait()
end = time.monotonic()
log.debug("Task finished", exit_code=exit_code, duration=end - start)
log.info("Task finished", exit_code=exit_code, duration=end - start, final_state=process.final_state)
return exit_code
2 changes: 0 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,11 @@ def run(ti: RuntimeTaskInstance, log: Logger):
next_method = defer.method_name
timeout = defer.timeout
msg = DeferTask(
state="deferred",
classpath=classpath,
trigger_kwargs=trigger_kwargs,
next_method=next_method,
trigger_timeout=timeout,
)
global SUPERVISOR_COMMS
SUPERVISOR_COMMS.send_request(msg=msg, log=log)
except AirflowSkipException:
...
Expand Down
47 changes: 47 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,53 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine):
"timestamp": "2024-11-07T12:34:56.078901Z",
} in captured_logs

def test_supervise_handles_deferred_task(self, test_dags_dir, captured_logs, time_machine, mocker):
    """
    Check that ``supervise`` copes with a task that defers itself.

    The supervised run is expected to return exit code 0, to report the deferral through the API
    client with the expected trigger details, and to log "Task finished" with ``deferred`` as the
    final state.
    """
    task_instance = TaskInstance(
        id=uuid7(), task_id="async", dag_id="super_basic_deferred_run", run_id="d", try_number=1
    )
    dag_file = test_dags_dir / "super_basic_deferred_run.py"

    # The client implementation is taken on trust here; a spec'd mock only records
    # the calls the supervisor makes so they can be asserted on below.
    api_client = mocker.Mock(spec=sdk_client.Client)

    # Pin the clock (no ticking) so that time-derived values in the assertions are deterministic.
    frozen_at = tz.datetime(2024, 11, 7, 12, 34, 56, 0)
    time_machine.move_to(frozen_at, tick=False)

    # The supervised run must report a zero exit code.
    exit_code = supervise(ti=task_instance, dag_path=dag_file, token="", client=api_client)
    assert exit_code == 0

    # Check what was sent to the task-instance endpoints of the client.
    ti_api = api_client.task_instances
    ti_api.start.assert_called_once_with(task_instance.id, mocker.ANY, mocker.ANY)
    ti_api.heartbeat.assert_called_once_with(task_instance.id, pid=mocker.ANY)

    expected_defer_msg = DeferTask(
        classpath="airflow.providers.standard.triggers.temporal.DateTimeTrigger",
        trigger_kwargs={"moment": "2024-11-07T12:34:59Z", "end_from_trigger": False},
        next_method="execute_complete",
    )
    ti_api.defer.assert_called_once_with(task_instance.id, expected_defer_msg)

    # The "Task finished" log entry is the place where the supervisor exposes the final state,
    # so it is used to confirm both a clean exit and that the final state is ``deferred``.
    expected_log_entry = {
        "exit_code": 0,
        "duration": 0.0,
        "final_state": "deferred",
        "event": "Task finished",
        "timestamp": mocker.ANY,
        "level": "info",
        "logger": "supervisor",
    }
    assert expected_log_entry in captured_logs


def test_supervisor_handles_already_running_task(self):
"""Test that Supervisor prevents starting a Task Instance that is already running."""
ti = TaskInstance(id=uuid7(), task_id="b", dag_id="c", run_id="d", try_number=1)
Expand Down

0 comments on commit 33bf3ed

Please sign in to comment.