Skip to content

Commit

Permalink
AIP-72: Adding PUT Variable Endpoint for execution API (apache#44449)
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Nov 29, 2024
1 parent eee6919 commit fdd353a
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 2 deletions.
7 changes: 7 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ class VariableResponse(BaseModel):

key: str
val: str | None = Field(alias="value")


class VariablePostBody(BaseModel):
"""Request body schema for creating variables."""

value: str | None = Field(serialization_alias="val")
description: str | None = Field(default=None)
26 changes: 24 additions & 2 deletions airflow/api_fastapi/execution_api/routes/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import deps
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.api_fastapi.execution_api.datamodels.variable import VariablePostBody, VariableResponse
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
Expand Down Expand Up @@ -67,7 +67,29 @@ def get_variable(variable_key: str, token: deps.TokenDep) -> VariableResponse:
return VariableResponse(key=variable_key, value=variable_value)


def has_variable_access(variable_key: str, token: TIToken) -> bool:
@router.put(
"/{variable_key}",
status_code=status.HTTP_201_CREATED,
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def put_variable(variable_key: str, body: VariablePostBody, token: deps.TokenDep):
"""Set an Airflow Variable."""
if not has_variable_access(variable_key, token, write_access=True):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to write variable {variable_key}",
},
)
Variable.set(key=variable_key, value=body.value, description=body.description)
return {"message": "Variable successfully set"}


def has_variable_access(variable_key: str, token: TIToken, write_access: bool = False) -> bool:
"""Check if the task has access to the variable."""
# TODO: Placeholder for actual implementation

Expand Down
87 changes: 87 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,18 @@

from airflow.models.variable import Variable

from tests_common.test_utils.db import clear_db_variables

pytestmark = pytest.mark.db_test


@pytest.fixture(autouse=True)
def setup_method():
clear_db_variables()
yield
clear_db_variables()


class TestGetVariable:
def test_variable_get_from_db(self, client, session):
Variable.set(key="var1", value="value", session=session)
Expand Down Expand Up @@ -75,3 +84,81 @@ def test_variable_get_access_denied(self, client):
"message": "Task does not have access to variable key1",
}
}


class TestPostVariable:
@pytest.mark.parametrize(
"payload",
[
pytest.param({"value": "{}", "description": "description"}, id="valid-payload"),
pytest.param({"value": "{}"}, id="missing-description"),
],
)
def test_should_create_variable(self, client, payload, session):
key = "var_create"
response = client.put(
f"/execution/variables/{key}",
json=payload,
)
assert response.status_code == 201

var_from_db = session.query(Variable).where(Variable.key == "var_create").first()
assert var_from_db is not None
assert var_from_db.key == key
assert var_from_db.val == payload["value"]
if "description" in payload:
assert var_from_db.description == payload["description"]

@pytest.mark.parametrize(
"key, status_code, payload",
[
pytest.param("", 404, {"value": "{}", "description": "description"}, id="missing-key"),
pytest.param("var_create", 422, {"description": "description"}, id="missing-value"),
],
)
def test_variable_missing_fields(self, client, key, status_code, payload, session):
response = client.put(
f"/execution/variables/{key}",
json=payload,
)
assert response.status_code == status_code
if response.status_code == 422:
assert response.json()["detail"][0]["type"] == "missing"
assert response.json()["detail"][0]["msg"] == "Field required"

def test_overwriting_existing_variable(self, client, session):
key = "var_create"
Variable.set(key=key, value="value", session=session)
session.commit()

payload = {"value": "new_value"}
response = client.put(
f"/execution/variables/{key}",
json=payload,
)
assert response.status_code == 201
# variable should have been updated to the new value
var_from_db = session.query(Variable).where(Variable.key == key).first()
assert var_from_db is not None
assert var_from_db.key == key
assert var_from_db.val == payload["value"]

def test_post_variable_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False
):
key = "var_create"
payload = {"value": "{}"}
response = client.put(
f"/execution/variables/{key}",
json=payload,
)

# Assert response status code and detail for access denied
assert response.status_code == 403
assert response.json() == {
"detail": {
"reason": "access_denied",
"message": "Task does not have access to write variable var_create",
}
}

0 comments on commit fdd353a

Please sign in to comment.