Skip to content

Commit

Permalink
AIP-72: Add "XCom" GET endpoint for Execution API (apache#43894)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored Nov 12, 2024
1 parent e7b4937 commit 320e429
Show file tree
Hide file tree
Showing 9 changed files with 259 additions and 31 deletions.
10 changes: 9 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Annotated, Literal, Union
from typing import Annotated, Any, Literal, Union

from pydantic import (
BaseModel,
Expand Down Expand Up @@ -143,6 +143,14 @@ class VariableResponse(BaseModel):
val: str | None = Field(alias="value")


class XComResponse(BaseModel):
"""XCom schema for responses with fields that are needed for Runtime."""

key: str
value: Any
"""The returned XCom value in a JSON-compatible format."""


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""
Expand Down
32 changes: 32 additions & 0 deletions airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Annotated

from fastapi import Depends

from airflow.api_fastapi.execution_api import datamodels


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


TokenDep = Annotated[datamodels.TIToken, Depends(get_task_token)]
5 changes: 3 additions & 2 deletions airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from __future__ import annotations

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes import connections, health, task_instance, variables
from airflow.api_fastapi.execution_api.routes import connections, health, task_instances, variables, xcoms

execution_api_router = AirflowRouter()
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instance.router, prefix="/task_instance", tags=["Task Instance"])
execution_api_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
execution_api_router.include_router(variables.router, prefix="/variables", tags=["Variables"])
execution_api_router.include_router(xcoms.router, prefix="/xcoms", tags=["XComs"])
12 changes: 3 additions & 9 deletions airflow/api_fastapi/execution_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from __future__ import annotations

import logging
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, status

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection

Expand All @@ -35,11 +34,6 @@
log = logging.getLogger(__name__)


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


@router.get(
"/{connection_id}",
responses={
Expand All @@ -49,7 +43,7 @@ def get_task_token() -> datamodels.TIToken:
)
def get_connection(
connection_id: str,
token: Annotated[datamodels.TIToken, Depends(get_task_token)],
token: deps.TokenDep,
) -> datamodels.ConnectionResponse:
"""Get an Airflow connection."""
if not has_connection_access(connection_id, token):
Expand Down
12 changes: 3 additions & 9 deletions airflow/api_fastapi/execution_api/routes/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
from __future__ import annotations

import logging
from typing import Annotated

from fastapi import Depends, HTTPException, status
from fastapi import HTTPException, status

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
Expand All @@ -34,11 +33,6 @@
log = logging.getLogger(__name__)


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


@router.get(
"/{variable_key}",
responses={
Expand All @@ -48,7 +42,7 @@ def get_task_token() -> datamodels.TIToken:
)
def get_variable(
variable_key: str,
token: Annotated[datamodels.TIToken, Depends(get_task_token)],
token: deps.TokenDep,
) -> datamodels.VariableResponse:
"""Get an Airflow Variable."""
if not has_variable_access(variable_key, token):
Expand Down
116 changes: 116 additions & 0 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
import logging
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.models.xcom import BaseXCom

# TODO: Add dependency on JWT token
router = AirflowRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
)

log = logging.getLogger(__name__)


@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the XCom"},
},
)
def get_xcom(
dag_id: str,
run_id: str,
task_id: str,
key: str,
token: deps.TokenDep,
session: Annotated[Session, Depends(get_session)],
map_index: Annotated[int, Query()] = -1,
) -> datamodels.XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
if not has_xcom_access(key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to XCom key '{key}'",
},
)

# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
# (which automatically deserializes using the backend), we avoid potential
# performance hits from retrieving large data files into the API server.
query = BaseXCom.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
map_indexes=map_index,
limit=1,
session=session,
)

result = query.with_entities(BaseXCom.value).first()

if result is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"XCom with key '{key}' not found for task '{task_id}' in DAG '{dag_id}'",
},
)

try:
xcom_value = BaseXCom.deserialize_value(result)
except json.JSONDecodeError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"reason": "invalid_format",
"message": "XCom value is not a valid JSON",
},
)

return datamodels.XComResponse(key=key, value=xcom_value)


def has_xcom_access(xcom_key: str, token: datamodels.TIToken) -> bool:
"""Check if the task has access to the XCom."""
# TODO: Placeholder for actual implementation

ti_key = token.ti_key
log.debug(
"Checking access for task instance with key '%s' to XCom '%s'",
ti_key,
xcom_key,
)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_ti_update_state_to_running(self, client, session, create_task_instance)
session.commit()

response = client.patch(
f"/execution/task_instance/{ti.id}/state",
f"/execution/task-instances/{ti.id}/state",
json={
"state": "running",
"hostname": "random-hostname",
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_ti_update_state_conflict_if_not_queued(self, client, session, create_ta
session.commit()

response = client.patch(
f"/execution/task_instance/{ti.id}/state",
f"/execution/task-instances/{ti.id}/state",
json={
"state": "running",
"hostname": "random-hostname",
Expand Down Expand Up @@ -131,7 +131,7 @@ def test_ti_update_state_to_terminal(
session.commit()

response = client.patch(
f"/execution/task_instance/{ti.id}/state",
f"/execution/task-instances/{ti.id}/state",
json={
"state": state,
"end_date": end_date.isoformat(),
Expand All @@ -158,7 +158,7 @@ def test_ti_update_state_not_found(self, client, session):

payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"}

response = client.patch(f"/execution/task_instance/{task_instance_id}/state", json=payload)
response = client.patch(f"/execution/task-instances/{task_instance_id}/state", json=payload)
assert response.status_code == 404
assert response.json()["detail"] == {
"reason": "not_found",
Expand All @@ -183,13 +183,13 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
}

with mock.patch(
"airflow.api_fastapi.execution_api.routes.task_instance.Session.execute",
"airflow.api_fastapi.execution_api.routes.task_instances.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("queued",)), # First call returns "queued"
SQLAlchemyError("Database error"), # Second call raises an error
],
):
response = client.patch(f"/execution/task_instance/{ti.id}/state", json=payload)
response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)
assert response.status_code == 500
assert response.json()["detail"] == "Database error occurred"

Expand Down Expand Up @@ -263,7 +263,7 @@ def test_ti_heartbeat(
assert ti.last_heartbeat_at is None

response = client.put(
f"/execution/task_instance/{task_instance_id}/heartbeat",
f"/execution/task-instances/{task_instance_id}/heartbeat",
json={"hostname": hostname, "pid": pid},
)

Expand All @@ -287,7 +287,7 @@ def test_ti_heartbeat_non_existent_task(self, client, session, create_task_insta
assert session.scalar(select(TaskInstance.id).where(TaskInstance.id == task_instance_id)) is None

response = client.put(
f"/execution/task_instance/{task_instance_id}/heartbeat",
f"/execution/task-instances/{task_instance_id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

Expand Down Expand Up @@ -315,7 +315,7 @@ def test_ti_heartbeat_when_task_not_running(self, client, session, create_task_i
task_instance_id = ti.id

response = client.put(
f"/execution/task_instance/{task_instance_id}/heartbeat",
f"/execution/task-instances/{task_instance_id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

Expand Down Expand Up @@ -352,7 +352,7 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m
time_machine.move_to(new_time, tick=False)

response = client.put(
f"/execution/task_instance/{task_instance_id}/heartbeat",
f"/execution/task-instances/{task_instance_id}/heartbeat",
json={"hostname": "random-hostname", "pid": 1547},
)

Expand Down
Loading

0 comments on commit 320e429

Please sign in to comment.