From fdd353a03e4b058fff834b611880c2475abfac61 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 29 Nov 2024 12:07:45 +0530 Subject: [PATCH] AIP-72: Adding PUT Variable Endpoint for execution API (#44449) --- .../execution_api/datamodels/variable.py | 7 ++ .../execution_api/routes/variables.py | 26 +++++- .../execution_api/routes/test_variables.py | 87 +++++++++++++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py index 548d593476671..ce542af0d8440 100644 --- a/airflow/api_fastapi/execution_api/datamodels/variable.py +++ b/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -27,3 +27,10 @@ class VariableResponse(BaseModel): key: str val: str | None = Field(alias="value") + + +class VariablePostBody(BaseModel): + """Request body schema for creating variables.""" + + value: str | None = Field(serialization_alias="val") + description: str | None = Field(default=None) diff --git a/airflow/api_fastapi/execution_api/routes/variables.py b/airflow/api_fastapi/execution_api/routes/variables.py index e8e2012e8d1e1..0e454f7dae0fc 100644 --- a/airflow/api_fastapi/execution_api/routes/variables.py +++ b/airflow/api_fastapi/execution_api/routes/variables.py @@ -24,7 +24,7 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api import deps from airflow.api_fastapi.execution_api.datamodels.token import TIToken -from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse +from airflow.api_fastapi.execution_api.datamodels.variable import VariablePostBody, VariableResponse from airflow.models.variable import Variable # TODO: Add dependency on JWT token @@ -67,7 +67,29 @@ def get_variable(variable_key: str, token: deps.TokenDep) -> VariableResponse: return VariableResponse(key=variable_key, value=variable_value) -def has_variable_access(variable_key: str, token: TIToken) -> bool: +@router.put( + "/{variable_key}", + status_code=status.HTTP_201_CREATED, + responses={ + status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, + status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"}, + }, +) +def put_variable(variable_key: str, body: VariablePostBody, token: deps.TokenDep): + """Set an Airflow Variable.""" + if not has_variable_access(variable_key, token, write_access=True): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "reason": "access_denied", + "message": f"Task does not have access to write variable {variable_key}", + }, + ) + Variable.set(key=variable_key, value=body.value, description=body.description) + return {"message": "Variable successfully set"} + + +def has_variable_access(variable_key: str, token: TIToken, write_access: bool = False) -> bool: """Check if the task has access to the variable.""" # TODO: Placeholder for actual implementation diff --git a/tests/api_fastapi/execution_api/routes/test_variables.py b/tests/api_fastapi/execution_api/routes/test_variables.py index 67247e4adb955..9ae7f9a27395f 100644 --- a/tests/api_fastapi/execution_api/routes/test_variables.py +++ b/tests/api_fastapi/execution_api/routes/test_variables.py @@ -23,9 +23,18 @@ from airflow.models.variable import Variable +from tests_common.test_utils.db import clear_db_variables + pytestmark = pytest.mark.db_test +@pytest.fixture(autouse=True) +def setup_method(): + clear_db_variables() + yield + clear_db_variables() + + class TestGetVariable: def test_variable_get_from_db(self, client, session): Variable.set(key="var1", value="value", session=session) @@ -75,3 +84,81 @@ def test_variable_get_access_denied(self, client): "message": "Task does not have access to variable key1", } } + + +class TestPostVariable: + @pytest.mark.parametrize( + "payload", + [ + pytest.param({"value": "{}", "description": "description"}, id="valid-payload"), + pytest.param({"value": "{}"}, id="missing-description"), + ], + ) + def test_should_create_variable(self, client, payload, session): + key = "var_create" + response = client.put( + f"/execution/variables/{key}", + json=payload, + ) + assert response.status_code == 201 + + var_from_db = session.query(Variable).where(Variable.key == "var_create").first() + assert var_from_db is not None + assert var_from_db.key == key + assert var_from_db.val == payload["value"] + if "description" in payload: + assert var_from_db.description == payload["description"] + + @pytest.mark.parametrize( + "key, status_code, payload", + [ + pytest.param("", 404, {"value": "{}", "description": "description"}, id="missing-key"), + pytest.param("var_create", 422, {"description": "description"}, id="missing-value"), + ], + ) + def test_variable_missing_fields(self, client, key, status_code, payload, session): + response = client.put( + f"/execution/variables/{key}", + json=payload, + ) + assert response.status_code == status_code + if response.status_code == 422: + assert response.json()["detail"][0]["type"] == "missing" + assert response.json()["detail"][0]["msg"] == "Field required" + + def test_overwriting_existing_variable(self, client, session): + key = "var_create" + Variable.set(key=key, value="value", session=session) + session.commit() + + payload = {"value": "new_value"} + response = client.put( + f"/execution/variables/{key}", + json=payload, + ) + assert response.status_code == 201 + # variable should have been updated to the new value + var_from_db = session.query(Variable).where(Variable.key == key).first() + assert var_from_db is not None + assert var_from_db.key == key + assert var_from_db.val == payload["value"] + + def test_post_variable_access_denied(self, client): + with mock.patch( + "airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False + ): + key = "var_create" + payload = {"value": "{}"} + response = client.put( + f"/execution/variables/{key}", + json=payload, + ) + + # Assert response status code and detail for access denied + assert response.status_code == 403 + assert response.json() == { + "detail": { + "reason": "access_denied", + "message": "Task does not have access to write variable var_create", + } + }