From 01302a1822910f77b90f23f7504fadd4c0d3f295 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 8 Nov 2024 23:55:46 +0000 Subject: [PATCH] AIP-72: Add "Get Variable" endpoint for Execution API (#43832) This commit introduces a new endpoint, `/execution/variable/{variable_key}`, in the Execution API to retrieve Variables details. Same as the Connections PR, it uses a placeholder `check_connection_access` function to validate task permissions for each request. --- .../api_fastapi/execution_api/datamodels.py | 9 ++ .../execution_api/routes/__init__.py | 11 ++- .../execution_api/routes/connections.py | 6 +- .../execution_api/routes/health.py | 4 +- .../execution_api/routes/task_instance.py | 9 +- .../execution_api/routes/variables.py | 87 +++++++++++++++++++ ...test_connection.py => test_connections.py} | 8 +- .../execution_api/routes/test_variables.py | 77 ++++++++++++++++ 8 files changed, 189 insertions(+), 22 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/routes/variables.py rename tests/api_fastapi/execution_api/routes/{test_connection.py => test_connections.py} (92%) create mode 100644 tests/api_fastapi/execution_api/routes/test_variables.py diff --git a/airflow/api_fastapi/execution_api/datamodels.py b/airflow/api_fastapi/execution_api/datamodels.py index c61718bbef214..78dbca76bce96 100644 --- a/airflow/api_fastapi/execution_api/datamodels.py +++ b/airflow/api_fastapi/execution_api/datamodels.py @@ -134,6 +134,15 @@ class ConnectionResponse(BaseModel): extra: str | None +class VariableResponse(BaseModel): + """Variable schema for responses with fields that are needed for Runtime.""" + + model_config = ConfigDict(from_attributes=True) + + key: str + val: str | None = Field(alias="value") + + # TODO: This is a placeholder for Task Identity Token schema. class TIToken(BaseModel): """Task Identity Token.""" diff --git a/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow/api_fastapi/execution_api/routes/__init__.py index c2ee885fab6a5..76479c96cf77f 100644 --- a/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,11 +17,10 @@ from __future__ import annotations from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api.routes.connections import connection_router -from airflow.api_fastapi.execution_api.routes.health import health_router -from airflow.api_fastapi.execution_api.routes.task_instance import ti_router +from airflow.api_fastapi.execution_api.routes import connections, health, task_instance, variables execution_api_router = AirflowRouter() -execution_api_router.include_router(connection_router) -execution_api_router.include_router(health_router) -execution_api_router.include_router(ti_router) +execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) +execution_api_router.include_router(health.router, tags=["Health"]) +execution_api_router.include_router(task_instance.router, prefix="/task_instance", tags=["Task Instance"]) +execution_api_router.include_router(variables.router, prefix="/variables", tags=["Variables"]) diff --git a/airflow/api_fastapi/execution_api/routes/connections.py b/airflow/api_fastapi/execution_api/routes/connections.py index 4e0c6eb007c69..553cb0785d671 100644 --- a/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow/api_fastapi/execution_api/routes/connections.py @@ -28,9 +28,7 @@ from airflow.models.connection import Connection # TODO: Add dependency on JWT token -connection_router = AirflowRouter( - prefix="/connection", - tags=["Connection"], +router = AirflowRouter( responses={status.HTTP_404_NOT_FOUND: {"description": "Connection not found"}}, ) @@ -42,7 +40,7 @@ def get_task_token() -> datamodels.TIToken: return datamodels.TIToken(ti_key="test_key") -@connection_router.get( +@router.get( "/{connection_id}", responses={ status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, diff --git a/airflow/api_fastapi/execution_api/routes/health.py b/airflow/api_fastapi/execution_api/routes/health.py index c8d903815dc69..7bf4c4a2de0af 100644 --- a/airflow/api_fastapi/execution_api/routes/health.py +++ b/airflow/api_fastapi/execution_api/routes/health.py @@ -19,9 +19,9 @@ from airflow.api_fastapi.common.router import AirflowRouter -health_router = AirflowRouter(tags=["Health"]) +router = AirflowRouter() -@health_router.get("/health") +@router.get("/health") def health() -> dict: return {"status": "healthy"} diff --git a/airflow/api_fastapi/execution_api/routes/task_instance.py b/airflow/api_fastapi/execution_api/routes/task_instance.py index 4612b0c0425bc..3ef37013f89b7 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instance.py +++ b/airflow/api_fastapi/execution_api/routes/task_instance.py @@ -35,16 +35,13 @@ from airflow.utils.state import State # TODO: Add dependency on JWT token -ti_router = AirflowRouter( - prefix="/task_instance", - tags=["Task Instance"], -) +router = AirflowRouter() log = logging.getLogger(__name__) -@ti_router.patch( +@router.patch( "/{task_instance_id}/state", status_code=status.HTTP_204_NO_CONTENT, # TODO: Add description to the operation @@ -133,7 +130,7 @@ def ti_update_state( ) -@ti_router.put( +@router.put( "/{task_instance_id}/heartbeat", status_code=status.HTTP_204_NO_CONTENT, responses={ diff --git a/airflow/api_fastapi/execution_api/routes/variables.py b/airflow/api_fastapi/execution_api/routes/variables.py new file mode 100644 index 0000000000000..79df5678aca4b --- /dev/null +++ b/airflow/api_fastapi/execution_api/routes/variables.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging + +from fastapi import Depends, HTTPException, status +from typing_extensions import Annotated + +from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.execution_api import datamodels +from airflow.models.variable import Variable + +# TODO: Add dependency on JWT token +router = AirflowRouter( + responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}}, +) + +log = logging.getLogger(__name__) + + +def get_task_token() -> datamodels.TIToken: + """TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation.""" + return datamodels.TIToken(ti_key="test_key") + + +@router.get( + "/{variable_key}", + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"}, + }, +) +def get_variable( + variable_key: str, + token: Annotated[datamodels.TIToken, Depends(get_task_token)], +) -> datamodels.VariableResponse: + """Get an Airflow Variable.""" + if not has_variable_access(variable_key, token): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "reason": "access_denied", + "message": f"Task does not have access to variable {variable_key}", + }, + ) + + try: + variable_value = Variable.get(variable_key) + except KeyError: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"Variable with key '{variable_key}' not found", + }, + ) + + return datamodels.VariableResponse(key=variable_key, value=variable_value) + + +def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool: + """Check if the task has access to the variable.""" + # TODO: Placeholder for actual implementation + + ti_key = token.ti_key + log.debug( + "Checking access for task instance with key '%s' to variable '%s'", + ti_key, + variable_key, + ) + return True diff --git a/tests/api_fastapi/execution_api/routes/test_connection.py b/tests/api_fastapi/execution_api/routes/test_connections.py similarity index 92% rename from tests/api_fastapi/execution_api/routes/test_connection.py rename to tests/api_fastapi/execution_api/routes/test_connections.py index 107fb8741eac6..287a5bac8019a 100644 --- a/tests/api_fastapi/execution_api/routes/test_connection.py +++ b/tests/api_fastapi/execution_api/routes/test_connections.py @@ -43,7 +43,7 @@ def test_connection_get_from_db(self, client, session): session.add(connection) session.commit() - response = client.get("/execution/connection/test_conn") + response = client.get("/execution/connections/test_conn") assert response.status_code == 200 assert response.json() == { @@ -66,7 +66,7 @@ def test_connection_get_from_db(self, client, session): {"AIRFLOW_CONN_TEST_CONN2": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}'}, ) def test_connection_get_from_env_var(self, client, session): - response = client.get("/execution/connection/test_conn2") + response = client.get("/execution/connections/test_conn2") assert response.status_code == 200 assert response.json() == { @@ -81,7 +81,7 @@ def test_connection_get_from_env_var(self, client, session): } def test_connection_get_not_found(self, client): - response = client.get("/execution/connection/non_existent_test_conn") + response = client.get("/execution/connections/non_existent_test_conn") assert response.status_code == 404 assert response.json() == { @@ -95,7 +95,7 @@ def test_connection_get_access_denied(self, client): with mock.patch( "airflow.api_fastapi.execution_api.routes.connections.has_connection_access", return_value=False ): - response = client.get("/execution/connection/test_conn") + response = client.get("/execution/connections/test_conn") # Assert response status code and detail for access denied assert response.status_code == 403 diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py new file mode 100644 index 0000000000000..67247e4adb955 --- /dev/null +++ b/tests/api_fastapi/execution_api/routes/test_variables.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.models.variable import Variable + +pytestmark = pytest.mark.db_test + + +class TestGetVariable: + def test_variable_get_from_db(self, client, session): + Variable.set(key="var1", value="value", session=session) + session.commit() + + response = client.get("/execution/variables/var1") + + assert response.status_code == 200 + assert response.json() == {"key": "var1", "value": "value"} + + # Remove connection + Variable.delete(key="var1", session=session) + session.commit() + + @mock.patch.dict( + "os.environ", + {"AIRFLOW_VAR_KEY1": "VALUE"}, + ) + def test_variable_get_from_env_var(self, client, session): + response = client.get("/execution/variables/key1") + + assert response.status_code == 200 + assert response.json() == {"key": "key1", "value": "VALUE"} + + def test_variable_get_not_found(self, client): + response = client.get("/execution/variables/non_existent_var") + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "Variable with key 'non_existent_var' not found", + "reason": "not_found", + } + } + + def test_variable_get_access_denied(self, client): + with mock.patch( + "airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False + ): + response = client.get("/execution/variables/key1") + + # Assert response status code and detail for access denied + assert response.status_code == 403 + assert response.json() == { + "detail": { + "reason": "access_denied", + "message": "Task does not have access to variable key1", + } + }