From 4de24a18d0b917685e523bf20f997e8a2ea25135 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Thu, 19 Dec 2024 15:21:13 +0530 Subject: [PATCH] AIP-72: Allow retrieving Connection from Task Context (#45043) part of https://github.com/apache/airflow/issues/44481 - Added a minimal Connection user-facing object in Task SDK definition for use in the DAG file - Added logic to get Connections in the context. Fixed some bugs in the way related to Connection parsing/serializing! Now, we have following Connection related objects: - `ConnectionResponse` is auto-generated and tightly coupled with the API schema. - `ConnectionResult` is runtime-specific and meant for internal communication between Supervisor & Task Runner. - `Connection` class here is where the public-facing, user-relevant aspects are exposed, hiding internal details. **Next up**: - Same for XCom & Variable - Implementation of BaseHook.get_conn Tested it with a DAG: image DAG: ```py from __future__ import annotations from airflow.models.baseoperator import BaseOperator from airflow.models.dag import dag class CustomOperator(BaseOperator): def execute(self, context): import os os.environ["AIRFLOW_CONN_AIRFLOW_DB"] = "sqlite:///home/airflow/airflow.db" task_id = context["task_instance"].task_id print(f"Hello World {task_id}!") print(context) print(context["conn"].airflow_db) assert context["conn"].airflow_db.conn_id == "airflow_db" @dag() def super_basic_run(): CustomOperator(task_id="hello") super_basic_run() ``` For case where a **connection is not found** image --- task_sdk/src/airflow/sdk/__init__.py | 3 + task_sdk/src/airflow/sdk/api/client.py | 17 ++- .../src/airflow/sdk/definitions/connection.py | 52 +++++++++ task_sdk/src/airflow/sdk/exceptions.py | 21 ++++ .../src/airflow/sdk/execution_time/comms.py | 16 ++- .../src/airflow/sdk/execution_time/context.py | 78 +++++++++++++ .../airflow/sdk/execution_time/supervisor.py | 9 +- .../airflow/sdk/execution_time/task_runner.py | 13 ++- task_sdk/tests/execution_time/test_context.py | 103 ++++++++++++++++++ .../tests/execution_time/test_supervisor.py | 2 +- .../tests/execution_time/test_task_runner.py | 68 +++++++++++- 11 files changed, 373 insertions(+), 9 deletions(-) create mode 100644 task_sdk/src/airflow/sdk/definitions/connection.py create mode 100644 task_sdk/src/airflow/sdk/execution_time/context.py create mode 100644 task_sdk/tests/execution_time/test_context.py diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index bd882f43dd0b3..13f2819a1c6e6 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -25,6 +25,7 @@ "Label", "TaskGroup", "dag", + "Connection", "__version__", ] @@ -32,6 +33,7 @@ if TYPE_CHECKING: from airflow.sdk.definitions.baseoperator import BaseOperator + from airflow.sdk.definitions.connection import Connection from airflow.sdk.definitions.dag import DAG, dag from airflow.sdk.definitions.edges import EdgeModifier, Label from airflow.sdk.definitions.taskgroup import TaskGroup @@ -43,6 +45,7 @@ "TaskGroup": ".definitions.taskgroup", "EdgeModifier": ".definitions.edges", "Label": ".definitions.edges", + "Connection": ".definitions.connection", } diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 787dcf55ab818..da91c2bd98dd2 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -19,6 +19,7 @@ import sys import uuid +from http import HTTPStatus from typing import TYPE_CHECKING, Any, TypeVar import httpx @@ -43,6 +44,8 @@ VariableResponse, XComResponse, ) +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import ErrorResponse from airflow.utils.net import get_hostname from airflow.utils.platform import getuser @@ -161,9 +164,19 @@ class ConnectionOperations: def __init__(self, client: Client): self.client = client - def get(self, conn_id: str) -> ConnectionResponse: + def get(self, conn_id: str) -> ConnectionResponse | ErrorResponse: """Get a connection from the API server.""" - resp = self.client.get(f"connections/{conn_id}") + try: + resp = self.client.get(f"connections/{conn_id}") + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.error( + "Connection not found", + conn_id=conn_id, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": conn_id}) return ConnectionResponse.model_validate_json(resp.read()) diff --git a/task_sdk/src/airflow/sdk/definitions/connection.py b/task_sdk/src/airflow/sdk/definitions/connection.py new file mode 100644 index 0000000000000..628b72e29be6e --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/connection.py @@ -0,0 +1,52 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import attrs + + +@attrs.define +class Connection: + """ + A connection to an external data source. + + :param conn_id: The connection ID. + :param conn_type: The connection type. + :param description: The connection description. + :param host: The host. + :param login: The login. + :param password: The password. + :param schema: The schema. + :param port: The port number. + :param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON + encoded object. + """ + + conn_id: str + conn_type: str + description: str | None = None + host: str | None = None + schema: str | None = None + login: str | None = None + password: str | None = None + port: int | None = None + extra: str | None = None + + def get_uri(self): ... + + def get_hook(self): ... diff --git a/task_sdk/src/airflow/sdk/exceptions.py b/task_sdk/src/airflow/sdk/exceptions.py index 13a83393a9124..c713f38eef861 100644 --- a/task_sdk/src/airflow/sdk/exceptions.py +++ b/task_sdk/src/airflow/sdk/exceptions.py @@ -14,3 +14,24 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +from __future__ import annotations + +import enum +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.sdk.execution_time.comms import ErrorResponse + + +class AirflowRuntimeError(Exception): + def __init__(self, error: ErrorResponse): + self.error = error + super().__init__(f"{error.error.value}: {error.detail}") + + +class ErrorType(enum.Enum): + CONNECTION_NOT_FOUND = "CONNECTION_NOT_FOUND" + VARIABLE_NOT_FOUND = "VARIABLE_NOT_FOUND" + XCOM_NOT_FOUND = "XCOM_NOT_FOUND" + GENERIC_ERROR = "GENERIC_ERROR" diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index e1a3ce034611f..31690815af4f6 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -59,6 +59,7 @@ VariableResponse, XComResponse, ) +from airflow.sdk.exceptions import ErrorType class StartupDetails(BaseModel): @@ -85,13 +86,26 @@ class XComResult(XComResponse): class ConnectionResult(ConnectionResponse): type: Literal["ConnectionResult"] = "ConnectionResult" + @classmethod + def from_conn_response(cls, connection_response: ConnectionResponse) -> ConnectionResult: + # Exclude defaults to avoid sending unnecessary data + # Pass the type as ConnectionResult explicitly so we can then call model_dump_json with exclude_unset=True + # to avoid sending unset fields (which are defaults in our case). + return cls(**connection_response.model_dump(exclude_defaults=True), type="ConnectionResult") + class VariableResult(VariableResponse): type: Literal["VariableResult"] = "VariableResult" +class ErrorResponse(BaseModel): + error: ErrorType = ErrorType.GENERIC_ERROR + detail: dict | None = None + type: Literal["ErrorResponse"] = "ErrorResponse" + + ToTask = Annotated[ - Union[StartupDetails, XComResult, ConnectionResult, VariableResult], + Union[StartupDetails, XComResult, ConnectionResult, VariableResult, ErrorResponse], Field(discriminator="type"), ] diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py new file mode 100644 index 0000000000000..30295a84f9d29 --- /dev/null +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import structlog + +from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType + +if TYPE_CHECKING: + from airflow.sdk.definitions.connection import Connection + from airflow.sdk.execution_time.comms import ConnectionResult + + +def _convert_connection_result_conn(conn_result: ConnectionResult): + from airflow.sdk.definitions.connection import Connection + + # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model + return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) + + +def _get_connection(conn_id: str) -> Connection: + # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` + # or `airflow.sdk.execution_time.connection` + # A reason to not move it to `airflow.sdk.execution_time.comms` is that it + # will make that module depend on Task SDK, which is not ideal because we intend to + # keep Task SDK as a separate package than execution time mods. + from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id)) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, ConnectionResult) + return _convert_connection_result_conn(msg) + + +class ConnectionAccessor: + """Wrapper to access Connection entries in template.""" + + def __getattr__(self, conn_id: str) -> Any: + return _get_connection(conn_id) + + def __repr__(self) -> str: + return "" + + def __eq__(self, other): + if not isinstance(other, ConnectionAccessor): + return False + # All instances of ConnectionAccessor are equal since it is a stateless dynamic accessor + return True + + def get(self, conn_id: str, default_conn: Any = None) -> Any: + try: + return _get_connection(conn_id) + except AirflowRuntimeError as e: + if e.error.error == ErrorType.CONNECTION_NOT_FOUND: + return default_conn + raise diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index bf2aa0778c3b2..1c2a49b9e748e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -44,12 +44,15 @@ from airflow.sdk.api.client import Client, ServerResponseError from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, IntermediateTIState, TaskInstance, TerminalTIState, ) from airflow.sdk.execution_time.comms import ( + ConnectionResult, DeferTask, + ErrorResponse, GetConnection, GetVariable, GetXCom, @@ -689,7 +692,11 @@ def _handle_request(self, msg, log): self._task_end_time_monotonic = time.monotonic() elif isinstance(msg, GetConnection): conn = self.client.connections.get(msg.conn_id) - resp = conn.model_dump_json(exclude_unset=True).encode() + if isinstance(conn, ConnectionResponse): + conn_result = ConnectionResult.from_conn_response(conn) + resp = conn_result.model_dump_json(exclude_unset=True).encode() + elif isinstance(conn, ErrorResponse): + resp = conn.model_dump_json().encode() elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) resp = var.model_dump_json(exclude_unset=True).encode() diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index ba4ed881039e7..7844da816baac 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,6 +40,7 @@ ToSupervisor, ToTask, ) +from airflow.sdk.execution_time.context import ConnectionAccessor if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -53,6 +54,9 @@ class RuntimeTaskInstance(TaskInstance): """The Task Instance context from the API server, if any.""" def get_template_context(self): + # TODO: Move this to `airflow.sdk.execution_time.context` + # once we port the entire context logic from airflow/utils/context.py ? + # TODO: Assess if we need to it through airflow.utils.timezone.coerce_datetime() context: dict[str, Any] = { # From the Task Execution interface @@ -63,6 +67,8 @@ def get_template_context(self): "run_id": self.run_id, "task": self.task, "task_instance": self, + # TODO: Ensure that ti.log_url and such are available to use in context + # especially after removal of `conf` from Context. "ti": self, # "outlet_events": OutletEventAccessors(), # "expanded_ti_count": expanded_ti_count, @@ -73,14 +79,13 @@ def get_template_context(self): # "prev_data_interval_end_success": get_prev_data_interval_end_success(), # "prev_start_date_success": get_prev_start_date_success(), # "prev_end_date_success": get_prev_end_date_success(), - # "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}", # "test_mode": task_instance.test_mode, # "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events), # "var": { # "json": VariableAccessor(deserialize_json=True), # "value": VariableAccessor(deserialize_json=False), # }, - # "conn": ConnectionAccessor(), + "conn": ConnectionAccessor(), } if self._ti_context_from_server: dag_run = self._ti_context_from_server.dag_run @@ -108,6 +113,10 @@ def get_template_context(self): context.update(context_from_server) return context + def xcom_pull(self, *args, **kwargs): ... + + def xcom_push(self, *args, **kwargs): ... + def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Task-SDK: diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py new file mode 100644 index 0000000000000..65d2b50f8a17f --- /dev/null +++ b/task_sdk/tests/execution_time/test_context.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +from airflow.sdk.definitions.connection import Connection +from airflow.sdk.exceptions import ErrorType +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse +from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn + + +def test_convert_connection_result_conn(): + """Test that the ConnectionResult is converted to a Connection object.""" + conn = ConnectionResult( + conn_id="test_conn", + conn_type="mysql", + host="mysql", + schema="airflow", + login="root", + password="password", + port=1234, + extra='{"extra_key": "extra_value"}', + ) + conn = _convert_connection_result_conn(conn) + assert conn == Connection( + conn_id="test_conn", + conn_type="mysql", + host="mysql", + schema="airflow", + login="root", + password="password", + port=1234, + extra='{"extra_key": "extra_value"}', + ) + + +class TestConnectionAccessor: + def test_getattr_connection(self): + """ + Test that the connection is fetched when accessed via __getattr__. + + The __getattr__ method is used for template rendering. Example: ``{{ conn.mysql_conn.host }}``. + """ + accessor = ConnectionAccessor() + + # Conn from the supervisor / API Server + conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = conn_result + + # Fetch the connection; Triggers __getattr__ + conn = accessor.mysql_conn + + expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) + assert conn == expected_conn + + def test_get_method_valid_connection(self): + """Test that the get method returns the requested connection using `conn.get`.""" + accessor = ConnectionAccessor() + conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = conn_result + + conn = accessor.get("mysql_conn") + assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306) + + def test_get_method_with_default(self): + """Test that the get method returns the default connection when the requested connection is not found.""" + accessor = ConnectionAccessor() + default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"} + error_response = ErrorResponse( + error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"} + ) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = error_response + + conn = accessor.get("nonexistent_conn", default_conn=default_conn) + assert conn == default_conn diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index be1d945b82e4c..73ac6dea630ba 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -764,7 +764,7 @@ def watched_subprocess(self, mocker): [ pytest.param( GetConnection(conn_id="test_conn"), - b'{"conn_id":"test_conn","conn_type":"mysql"}\n', + b'{"conn_id":"test_conn","conn_type":"mysql","type":"ConnectionResult"}\n', "connections.get", ("test_conn",), ConnectionResult(conn_id="test_conn", conn_type="mysql"), diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 35ff65414f837..78f8058accc0b 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -27,9 +27,17 @@ from uuid6 import uuid7 from airflow.exceptions import AirflowFailException, AirflowSensorTimeout, AirflowSkipException -from airflow.sdk import DAG, BaseOperator +from airflow.sdk import DAG, BaseOperator, Connection from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState -from airflow.sdk.execution_time.comms import DeferTask, SetRenderedFields, StartupDetails, TaskState +from airflow.sdk.execution_time.comms import ( + ConnectionResult, + DeferTask, + GetConnection, + SetRenderedFields, + StartupDetails, + TaskState, +) +from airflow.sdk.execution_time.context import ConnectionAccessor from airflow.sdk.execution_time.task_runner import CommsDecoder, RuntimeTaskInstance, parse, run, startup from airflow.utils import timezone @@ -399,6 +407,7 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ # Verify the context keys and values assert context == { + "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, "map_index_template": task.map_index_template, @@ -431,6 +440,7 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con context = runtime_ti.get_template_context() assert context == { + "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, "map_index_template": task.map_index_template, @@ -450,3 +460,57 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con "ts_nodash": "20241201T010000", "ts_nodash_with_tz": "20241201T010000+0000", } + + def test_get_connection_from_context(self, mocked_parse, make_ti_context): + """Test that the connection is fetched from the API server via the Supervisor lazily when accessed""" + + task = BaseOperator(task_id="hello") + + ti_id = uuid7() + ti = TaskInstance( + id=ti_id, task_id=task.task_id, dag_id="basic_task", run_id="test_run", try_number=1 + ) + conn = ConnectionResult( + conn_id="test_conn", + conn_type="mysql", + host="mysql", + schema="airflow", + login="root", + password="password", + port=1234, + extra='{"extra_key": "extra_value"}', + ) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = conn + + context = runtime_ti.get_template_context() + + # Assert that the connection is not fetched from the API server yet! + # The connection should be only fetched connection is accessed + mock_supervisor_comms.send_request.assert_not_called() + mock_supervisor_comms.get_message.assert_not_called() + + # Access the connection from the context + conn_from_context = context["conn"].test_conn + + mock_supervisor_comms.send_request.assert_called_once_with( + log=mock.ANY, msg=GetConnection(conn_id="test_conn") + ) + mock_supervisor_comms.get_message.assert_called_once_with() + + assert conn_from_context == Connection( + conn_id="test_conn", + conn_type="mysql", + description=None, + host="mysql", + schema="airflow", + login="root", + password="password", + port=1234, + extra='{"extra_key": "extra_value"}', + )