diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index eaae71dae3fd9..c97a8bc22c6b5 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -35,6 +35,7 @@ TIHeartbeatInfo, TITerminalStatePayload, ValidationError as RemoteValidationError, + VariableResponse, ) from airflow.utils.net import get_hostname from airflow.utils.platform import getuser @@ -124,17 +125,29 @@ def heartbeat(self, id: uuid.UUID, pid: int): class ConnectionOperations: - __slots__ = ("client", "decoder") + __slots__ = ("client",) def __init__(self, client: Client): self.client = client - def get(self, id: str) -> ConnectionResponse: + def get(self, conn_id: str) -> ConnectionResponse: """Get a connection from the API server.""" - resp = self.client.get(f"connection/{id}") + resp = self.client.get(f"connections/{conn_id}") return ConnectionResponse.model_validate_json(resp.read()) +class VariableOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, key: str) -> VariableResponse: + """Get a variable from the API server.""" + resp = self.client.get(f"variables/{key}") + return VariableResponse.model_validate_json(resp.read()) + + class BearerAuth(httpx.Auth): def __init__(self, token: str): self.token: str = token @@ -186,9 +199,15 @@ def task_instances(self) -> TaskInstanceOperations: @lru_cache() # type: ignore[misc] @property def connections(self) -> ConnectionOperations: - """Operations related to TaskInstances.""" + """Operations related to Connections.""" return ConnectionOperations(self) + @lru_cache() # type: ignore[misc] + @property + def variables(self) -> VariableOperations: + """Operations related to Variables.""" + return VariableOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 07b260a417d50..e8e8c16f258e3 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -43,11 +43,17 @@ from __future__ import annotations -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Literal, Union from pydantic import BaseModel, ConfigDict, Field -from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState # noqa: TCH001 +from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + TaskInstance, + TerminalTIState, + VariableResponse, + XComResponse, +) class StartupDetails(BaseModel): @@ -64,23 +70,22 @@ class StartupDetails(BaseModel): type: Literal["StartupDetails"] = "StartupDetails" -class XComResponse(BaseModel): +class XComResult(XComResponse): """Response to ReadXCom request.""" - key: str - value: Any + type: Literal["XComResult"] = "XComResult" - type: Literal["XComResponse"] = "XComResponse" +class ConnectionResult(ConnectionResponse): + type: Literal["ConnectionResult"] = "ConnectionResult" -class ConnectionResponse(BaseModel): - conn: Any - type: Literal["ConnectionResponse"] = "ConnectionResponse" +class VariableResult(VariableResponse): + type: Literal["VariableResult"] = "VariableResult" ToTask = Annotated[ - Union[StartupDetails, XComResponse, ConnectionResponse], + Union[StartupDetails, XComResult, ConnectionResult, VariableResult], Field(discriminator="type"), ] @@ -98,22 +103,22 @@ class TaskState(BaseModel): type: Literal["TaskState"] = "TaskState" -class ReadXCom(BaseModel): +class GetXCom(BaseModel): key: str - type: Literal["ReadXCom"] = "ReadXCom" + type: Literal["GetXCom"] = "GetXCom" class GetConnection(BaseModel): - id: str + conn_id: str type: Literal["GetConnection"] = "GetConnection" class GetVariable(BaseModel): - id: str + key: str type: Literal["GetVariable"] = "GetVariable" ToSupervisor = Annotated[ - Union[TaskState, ReadXCom, GetConnection, GetVariable], + Union[TaskState, GetXCom, GetConnection, GetVariable], Field(discriminator="type"), ] diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 1d72f7be633e3..f3789808071ff 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -45,8 +45,8 @@ from airflow.sdk.api.client import Client from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( - ConnectionResponse, GetConnection, + GetVariable, StartupDetails, ToSupervisor, ) @@ -480,10 +480,7 @@ def __repr__(self) -> str: return rep + " >" def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, None]: - encoder = ConnectionResponse.model_dump_json - # Use a buffer to avoid small allocations - buffer = bytearray(64) - + """Handle incoming requests from the task process, respond with the appropriate data.""" decoder = TypeAdapter[ToSupervisor](ToSupervisor) while True: @@ -495,28 +492,23 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N log.exception("Unable to decode message", line=line) continue - # if isinstnace(msg, TaskState): + # if isinstance(msg, TaskState): # self._terminal_state = msg.state # elif isinstance(msg, ReadXCom): # resp = XComResponse(key="secret", value=True) # encoder.encode_into(resp, buffer) # self.stdin.write(buffer + b"\n") if isinstance(msg, GetConnection): - conn = self.client.connections.get(msg.id) - resp = ConnectionResponse(conn=conn) - encoded_resp = encoder(resp) - buffer.extend(encoded_resp.encode()) + conn = self.client.connections.get(msg.conn_id) + resp = conn.model_dump_json(exclude_unset=True).encode() + elif isinstance(msg, GetVariable): + var = self.client.variables.get(msg.key) + resp = var.model_dump_json(exclude_unset=True).encode() else: log.error("Unhandled request", msg=msg) continue - buffer.extend(b"\n") - self.stdin.write(buffer) - - # Ensure the buffer doesn't grow and stay large if a large payload is used. This won't grow it - # larger than it is, but it will shrink it - if len(buffer) > 1024: - buffer = buffer[:1024] + self.stdin.write(resp + b"\n") # Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index 40a3279f3dbab..edf37f6ca897d 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -105,7 +105,16 @@ def send_request(self, log: Logger, msg: ToSupervisor): self.request_socket.write(encoded_msg) -# This global variable will be used by Connection/Variable classes etc to send requests to +# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, +# to send requests back to the supervisor process. +# +# Why it needs to be a global: +# - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests +# to the parent process during task execution. +# - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the +# deeply nested execution stack. +# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily +# accessible wherever needed during task execution without modifying every layer of the call stack. SUPERVISOR_COMMS: CommsDecoder # State machine! diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index eb6611b0f57cb..c7b622eef2c8a 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -21,6 +21,7 @@ import pytest from airflow.sdk.api.client import Client, RemoteValidationError, ServerResponseError +from airflow.sdk.api.datamodels._generated import VariableResponse class TestClient: @@ -74,3 +75,61 @@ def handle_request(request: httpx.Request) -> httpx.Response: client.get("http://error") assert err.value.args == ("Not found",) assert err.value.detail is None + + +def make_client(transport: httpx.MockTransport) -> Client: + """Get a client with a custom transport""" + return Client(base_url="test://server", token="", transport=transport) + + +class TestVariableOperations: + """ + Test that the VariableOperations class works as expected. While the operations are simple, it + still catches the basic functionality of the client for variables including endpoint and + response parsing. + """ + + def test_variable_get_success(self): + # Simulate a successful response from the server with a variable + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/variables/test_key": + return httpx.Response( + status_code=200, + json={"key": "test_key", "value": "test_value"}, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.variables.get(key="test_key") + + assert isinstance(result, VariableResponse) + assert result.key == "test_key" + assert result.value == "test_value" + + def test_variable_not_found(self): + # Simulate a 404 response from the server + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/variables/non_existent_var": + return httpx.Response( + status_code=404, + json={ + "detail": { + "message": "Variable with key 'non_existent_var' not found", + "reason": "not_found", + } + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError) as err: + client.variables.get(key="non_existent_var") + + assert err.value.response.status_code == 404 + assert err.value.detail == { + "detail": { + "message": "Variable with key 'non_existent_var' not found", + "reason": "not_found", + } + } diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 5127c626736c1..8d1117acd0e95 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -22,18 +22,20 @@ import os import signal import sys +from io import BytesIO +from operator import attrgetter from time import sleep from typing import TYPE_CHECKING from unittest.mock import MagicMock -from uuid import UUID import pytest import structlog -import structlog.testing +from uuid6 import uuid7 from airflow.sdk.api import client as sdk_client from airflow.sdk.api.datamodels._generated import TaskInstance from airflow.sdk.api.datamodels.activities import ExecuteTaskActivity +from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, GetVariable, VariableResult from airflow.sdk.execution_time.supervisor import WatchedSubprocess, supervise from airflow.utils import timezone as tz @@ -174,12 +176,12 @@ def subprocess_main(): print("output", flush=True) sleep(0.05) - id = UUID("4d828a62-a417-4936-a7a6-2b3fabacecab") + ti_id = uuid7() spy = spy_agency.spy_on(sdk_client.TaskInstanceOperations.heartbeat) proc = WatchedSubprocess.start( path=os.devnull, ti=TaskInstance( - id=id, + id=ti_id, task_id="b", dag_id="c", run_id="d", @@ -189,7 +191,7 @@ def subprocess_main(): target=subprocess_main, ) assert proc.wait() == 0 - assert spy.called_with(id, pid=proc.pid) # noqa: PGH005 + assert spy.called_with(ti_id, pid=proc.pid) # noqa: PGH005 # The exact number we get will depend on timing behaviour, so be a little lenient assert 1 <= len(spy.calls) <= 4 @@ -205,7 +207,7 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): dagfile_path = test_dags_dir / "super_basic_run.py" task_activity = ExecuteTaskActivity( ti=TaskInstance( - id=UUID("4d828a62-a417-4936-a7a6-2b3fabacecab"), + id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", @@ -225,3 +227,77 @@ def test_run_simple_dag(self, test_dags_dir, captured_logs, time_machine): "logger": "task", "timestamp": "2024-11-07T12:34:56.078901Z", } in captured_logs + + +class TestHandleRequest: + @pytest.fixture + def watched_subprocess(self, mocker): + """Fixture to provide a WatchedSubprocess instance.""" + return WatchedSubprocess( + ti_id=uuid7(), + pid=12345, + stdin=BytesIO(), + stdout=mocker.Mock(), # Not used in these tests + stderr=mocker.Mock(), # Not used in these tests + client=mocker.Mock(), + process=mocker.Mock(), + ) + + @pytest.mark.parametrize( + ["message", "expected_buffer", "client_attr_path", "method_arg", "mock_response"], + [ + pytest.param( + GetConnection(conn_id="test_conn"), + b'{"conn_id":"test_conn","conn_type":"mysql"}', + "connections.get", + "test_conn", + ConnectionResult(conn_id="test_conn", conn_type="mysql"), + id="get_connection", + ), + pytest.param( + GetVariable(key="test_key"), + b'{"key":"test_key","value":"test_value"}', + "variables.get", + "test_key", + VariableResult(key="test_key", value="test_value"), + id="get_variable", + ), + ], + ) + def test_handle_requests( + self, + watched_subprocess, + mocker, + message, + expected_buffer, + client_attr_path, + method_arg, + mock_response, + ): + """ + Test handling of different messages to the subprocess. For any new message type, add a + new parameter set to the `@pytest.mark.parametrize` decorator. + + For each message type, this test: + + 1. Sends the message to the subprocess. + 2. Verifies that the correct client method is called with the expected argument. + 3. Checks that the buffer is updated with the expected response. + """ + + # Mock the client method. E.g. `client.variables.get` or `client.connections.get` + mock_client_method = attrgetter(client_attr_path)(watched_subprocess.client) + mock_client_method.return_value = mock_response + + # Simulate the generator + generator = watched_subprocess.handle_requests(log=mocker.Mock()) + # Initialize the generator + next(generator) + msg = message.model_dump_json().encode() + b"\n" + generator.send(msg) + + # Verify the correct client method was called + mock_client_method.assert_called_once_with(method_arg) + + # Verify the response was added to the buffer + assert watched_subprocess.stdin.getvalue() == expected_buffer + b"\n"