From 320e4295b338250afe238b2265933c7472a20e8e Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 12 Nov 2024 12:35:50 +0000 Subject: [PATCH] AIP-72: Add "XCom" GET endpoint for Execution API (#43894) closes https://github.com/apache/airflow/issues/43839 --- .../api_fastapi/execution_api/datamodels.py | 10 +- airflow/api_fastapi/execution_api/deps.py | 32 +++++ .../execution_api/routes/__init__.py | 5 +- .../execution_api/routes/connections.py | 12 +- .../{task_instance.py => task_instances.py} | 0 .../execution_api/routes/variables.py | 12 +- .../api_fastapi/execution_api/routes/xcoms.py | 116 ++++++++++++++++++ ...ask_instance.py => test_task_instances.py} | 20 +-- .../execution_api/routes/test_xcoms.py | 83 +++++++++++++ 9 files changed, 259 insertions(+), 31 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/deps.py rename airflow/api_fastapi/execution_api/routes/{task_instance.py => task_instances.py} (100%) create mode 100644 airflow/api_fastapi/execution_api/routes/xcoms.py rename tests/api_fastapi/execution_api/routes/{test_task_instance.py => test_task_instances.py} (94%) create mode 100644 tests/api_fastapi/execution_api/routes/test_xcoms.py diff --git a/airflow/api_fastapi/execution_api/datamodels.py b/airflow/api_fastapi/execution_api/datamodels.py index 78dbca76bce96..32115c9ac5a4d 100644 --- a/airflow/api_fastapi/execution_api/datamodels.py +++ b/airflow/api_fastapi/execution_api/datamodels.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Annotated, Literal, Union +from typing import Annotated, Any, Literal, Union from pydantic import ( BaseModel, @@ -143,6 +143,14 @@ class VariableResponse(BaseModel): val: str | None = Field(alias="value") +class XComResponse(BaseModel): + """XCom schema for responses with fields that are needed for Runtime.""" + + key: str + value: Any + """The returned XCom value in a JSON-compatible format.""" + + # TODO: This is a placeholder for Task Identity Token schema. class TIToken(BaseModel): """Task Identity Token.""" diff --git a/airflow/api_fastapi/execution_api/deps.py b/airflow/api_fastapi/execution_api/deps.py new file mode 100644 index 0000000000000..4564324668a3a --- /dev/null +++ b/airflow/api_fastapi/execution_api/deps.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Annotated + +from fastapi import Depends + +from airflow.api_fastapi.execution_api import datamodels + + +def get_task_token() -> datamodels.TIToken: + """TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation.""" + return datamodels.TIToken(ti_key="test_key") + + +TokenDep = Annotated[datamodels.TIToken, Depends(get_task_token)] diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow/api_fastapi/execution_api/routes/__init__.py index 76479c96cf77f..0383503f18b87 100644 --- a/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,10 +17,11 @@ from __future__ import annotations from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api.routes import connections, health, task_instance, variables +from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms execution_api_router = AirflowRouter() execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) execution_api_router.include_router(health.router, tags=["Health"]) -execution_api_router.include_router(task_instance.router, prefix="/task_instance", tags=["Task Instance"]) +execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) execution_api_router.include_router(variables.router, prefix="/variables", tags=["Variables"]) +execution_api_router.include_router(xcoms.router, prefix="/xcoms", tags=["XComs"]) diff --git a/airflow/api_fastapi/execution_api/routes/connections.py b/airflow/api_fastapi/execution_api/routes/connections.py index 32d1f7afa185b..d31cfbaeb9d0e 100644 --- a/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow/api_fastapi/execution_api/routes/connections.py @@ -18,12 +18,11 @@ from __future__ import annotations import logging -from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import HTTPException, status from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels +from airflow.api_fastapi.execution_api import datamodels, deps from airflow.exceptions import AirflowNotFoundException from airflow.models.connection import Connection @@ -35,11 +34,6 @@ log = logging.getLogger(__name__) -def get_task_token() -> datamodels.TIToken: - """TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation.""" - return datamodels.TIToken(ti_key="test_key") - - @router.get( "/{connection_id}", responses={ @@ -49,7 +43,7 @@ def get_task_token() -> datamodels.TIToken: ) def get_connection( connection_id: str, - token: Annotated[datamodels.TIToken, Depends(get_task_token)], + token: deps.TokenDep, ) -> datamodels.ConnectionResponse: """Get an Airflow connection.""" if not has_connection_access(connection_id, token): diff --git a/airflow/api_fastapi/execution_api/routes/task_instance.py b/airflow/api_fastapi/execution_api/routes/task_instances.py similarity index 100% rename from airflow/api_fastapi/execution_api/routes/task_instance.py rename to airflow/api_fastapi/execution_api/routes/task_instances.py diff --git a/airflow/api_fastapi/execution_api/routes/variables.py b/airflow/api_fastapi/execution_api/routes/variables.py index 9dcdffedf7962..1ecc7480ee368 100644 --- a/airflow/api_fastapi/execution_api/routes/variables.py +++ b/airflow/api_fastapi/execution_api/routes/variables.py @@ -18,12 +18,11 @@ from __future__ import annotations import logging -from typing import Annotated -from fastapi import Depends, HTTPException, status +from fastapi import HTTPException, status from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels +from airflow.api_fastapi.execution_api import datamodels, deps from airflow.models.variable import Variable # TODO: Add dependency on JWT token @@ -34,11 +33,6 @@ log = logging.getLogger(__name__) -def get_task_token() -> datamodels.TIToken: - """TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation.""" - return datamodels.TIToken(ti_key="test_key") - - @router.get( "/{variable_key}", responses={ @@ -48,7 +42,7 @@ def get_task_token() -> datamodels.TIToken: ) def get_variable( variable_key: str, - token: Annotated[datamodels.TIToken, Depends(get_task_token)], + token: deps.TokenDep, ) -> datamodels.VariableResponse: """Get an Airflow Variable.""" if not has_variable_access(variable_key, token): diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py new file mode 100644 index 0000000000000..0f0a04ed26941 --- /dev/null +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -0,0 +1,116 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import json +import logging +from typing import Annotated + +from fastapi import Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from airflow.api_fastapi.common.db.common import get_session +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.execution_api import datamodels, deps +from airflow.models.xcom import BaseXCom + +# TODO: Add dependency on JWT token +router = AirflowRouter( + responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}}, +) + +log = logging.getLogger(__name__) + + +@router.get( + "/{dag_id}/{run_id}/{task_id}/{key}", + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the XCom"}, + }, +) +def get_xcom( + dag_id: str, + run_id: str, + task_id: str, + key: str, + token: deps.TokenDep, + session: Annotated[Session, Depends(get_session)], + map_index: Annotated[int, Query()] = -1, +) -> datamodels.XComResponse: + """Get an Airflow XCom from database - not other XCom Backends.""" + if not has_xcom_access(key, token): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "reason": "access_denied", + "message": f"Task does not have access to XCom key '{key}'", + }, + ) + + # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend. + # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead + # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one` + # (which automatically deserializes using the backend), we avoid potential + # performance hits from retrieving large data files into the API server. + query = BaseXCom.get_many( + run_id=run_id, + key=key, + task_ids=task_id, + dag_ids=dag_id, + map_indexes=map_index, + limit=1, + session=session, + ) + + result = query.with_entities(BaseXCom.value).first() + + if result is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'", + }, + ) + + try: + xcom_value = BaseXCom.deserialize_value(result) + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "reason": "invalid_format", + "message": "XCom value is not a valid JSON", + }, + ) + + return datamodels.XComResponse(key=key, value=xcom_value) + + +def has_xcom_access(xcom_key: str, token: datamodels.TIToken) -> bool: + """Check if the task has access to the XCom.""" + # TODO: Placeholder for actual implementation + + ti_key = token.ti_key + log.debug( + "Checking access for task instance with key '%s' to XCom '%s'", + ti_key, + xcom_key, + ) + return True diff --git a/tests/api_fastapi/execution_api/routes/test_task_instance.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py similarity index 94% rename from tests/api_fastapi/execution_api/routes/test_task_instance.py rename to tests/api_fastapi/execution_api/routes/test_task_instances.py index 3c78dd4c2eb24..efb48ccb533aa 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instance.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -58,7 +58,7 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance) session.commit() response = client.patch( - f"/execution/task_instance/{ti.id}/state", + f"/execution/task-instances/{ti.id}/state", json={ "state": "running", "hostname": "random-hostname", @@ -91,7 +91,7 @@ def test_ti_update_state_conflict_if_not_queued(self, client, session, create_ta session.commit() response = client.patch( - f"/execution/task_instance/{ti.id}/state", + f"/execution/task-instances/{ti.id}/state", json={ "state": "running", "hostname": "random-hostname", @@ -131,7 +131,7 @@ def test_ti_update_state_to_terminal( session.commit() response = client.patch( - f"/execution/task_instance/{ti.id}/state", + f"/execution/task-instances/{ti.id}/state", json={ "state": state, "end_date": end_date.isoformat(), @@ -158,7 +158,7 @@ def test_ti_update_state_not_found(self, client, session): payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"} - response = client.patch(f"/execution/task_instance/{task_instance_id}/state", json=payload) + response = client.patch(f"/execution/task-instances/{task_instance_id}/state", json=payload) assert response.status_code == 404 assert response.json()["detail"] == { "reason": "not_found", @@ -183,13 +183,13 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta } with mock.patch( - "airflow.api_fastapi.execution_api.routes.task_instance.Session.execute", + "airflow.api_fastapi.execution_api.routes.task_instances.Session.execute", side_effect=[ mock.Mock(one=lambda: ("queued",)), # First call returns "queued" SQLAlchemyError("Database error"), # Second call raises an error ], ): - response = client.patch(f"/execution/task_instance/{ti.id}/state", json=payload) + response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) assert response.status_code == 500 assert response.json()["detail"] == "Database error occurred" @@ -263,7 +263,7 @@ def test_ti_heartbeat( assert ti.last_heartbeat_at is None response = client.put( - f"/execution/task_instance/{task_instance_id}/heartbeat", + f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": hostname, "pid": pid}, ) @@ -287,7 +287,7 @@ def test_ti_heartbeat_non_existent_task(self, client, session, create_task_insta assert session.scalar(select(TaskInstance.id).where(TaskInstance.id == task_instance_id)) is None response = client.put( - f"/execution/task_instance/{task_instance_id}/heartbeat", + f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) @@ -315,7 +315,7 @@ def test_ti_heartbeat_when_task_not_running(self, client, session, create_task_i task_instance_id = ti.id response = client.put( - f"/execution/task_instance/{task_instance_id}/heartbeat", + f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) @@ -352,7 +352,7 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m time_machine.move_to(new_time, tick=False) response = client.put( - f"/execution/task_instance/{task_instance_id}/heartbeat", + f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py new file mode 100644 index 0000000000000..a7cdde6c64fbc --- /dev/null +++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py @@ -0,0 +1,83 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.models.dagrun import DagRun +from airflow.models.xcom import XCom +from airflow.utils.session import create_session + +pytestmark = pytest.mark.db_test + + +@pytest.fixture(autouse=True) +def reset_db(): + """Reset XCom entries.""" + with create_session() as session: + session.query(DagRun).delete() + session.query(XCom).delete() + + +class TestXComsGetEndpoint: + @pytest.mark.parametrize( + ("value", "expected_value"), + [ + ("value1", "value1"), + ({"key2": "value2"}, {"key2": "value2"}), + ({"key2": "value2", "key3": ["value3"]}, {"key2": "value2", "key3": ["value3"]}), + (["value1"], ["value1"]), + ], + ) + def test_xcom_get_from_db(self, client, create_task_instance, session, value, expected_value): + """Test that XCom value is returned from the database in JSON-compatible format.""" + ti = create_task_instance() + ti.xcom_push(key="xcom_1", value=value, session=session) + + session.commit() + + response = client.get(f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1") + + assert response.status_code == 200 + assert response.json() == {"key": "xcom_1", "value": expected_value} + + def test_xcom_not_found(self, client, create_task_instance): + response = client.get("/execution/xcoms/dag/runid/task/xcom_non_existent") + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "XCom with key 'xcom_non_existent' not found for task 'task' in DAG 'dag'", + "reason": "not_found", + } + } + + def test_xcom_access_denied(self, client): + with mock.patch("airflow.api_fastapi.execution_api.routes.xcoms.has_xcom_access", return_value=False): + response = client.get("/execution/xcoms/dag/runid/task/xcom_perms") + + # Assert response status code and detail for access denied + assert response.status_code == 403 + assert response.json() == { + "detail": { + "reason": "access_denied", + "message": "Task does not have access to XCom key 'xcom_perms'", + } + }