Skip to content

Commit

Permalink
Modularize datamodels in Execution API (apache#44068)
Browse files Browse the repository at this point in the history
Split the `datamodels` module to granular mods
  • Loading branch information
kaxil authored Nov 15, 2024
1 parent 4dfae23 commit 5c442d3
Show file tree
Hide file tree
Showing 11 changed files with 169 additions and 70 deletions.
16 changes: 16 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
33 changes: 33 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from pydantic import BaseModel, Field


class ConnectionResponse(BaseModel):
"""Connection schema for responses with fields that are needed for Runtime."""

conn_id: str
conn_type: str
host: str | None
schema_: str | None = Field(alias="schema")
login: str | None
password: str | None
port: int | None
extra: str | None
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,9 @@

from __future__ import annotations

from typing import Annotated, Any, Literal, Union
from typing import Annotated, Literal, Union

from pydantic import (
BaseModel,
ConfigDict,
Discriminator,
Field,
Tag,
WithJsonSchema,
)
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema

from airflow.api_fastapi.common.types import UtcDateTime
from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState
Expand Down Expand Up @@ -104,40 +97,3 @@ class TIHeartbeatInfo(BaseModel):

hostname: str
pid: int


class ConnectionResponse(BaseModel):
"""Connection schema for responses with fields that are needed for Runtime."""

conn_id: str
conn_type: str
host: str | None
schema_: str | None = Field(alias="schema")
login: str | None
password: str | None
port: int | None
extra: str | None


class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(from_attributes=True)

key: str
val: str | None = Field(alias="value")


class XComResponse(BaseModel):
"""XCom schema for responses with fields that are needed for Runtime."""

key: str
value: Any
"""The returned XCom value in a JSON-compatible format."""


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""

ti_key: str
27 changes: 27 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from pydantic import BaseModel


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""

ti_key: str
29 changes: 29 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from pydantic import BaseModel, ConfigDict, Field


class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(from_attributes=True)

key: str
val: str | None = Field(alias="value")
30 changes: 30 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/xcom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any

from pydantic import BaseModel


class XComResponse(BaseModel):
"""XCom schema for responses with fields that are needed for Runtime."""

key: str
value: Any
"""The returned XCom value in a JSON-compatible format."""
8 changes: 4 additions & 4 deletions airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@

from fastapi import Depends

from airflow.api_fastapi.execution_api import datamodels
from airflow.api_fastapi.execution_api.datamodels.token import TIToken


def get_task_token() -> datamodels.TIToken:
def get_task_token() -> TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")
return TIToken(ti_key="test_key")


TokenDep = Annotated[datamodels.TIToken, Depends(get_task_token)]
TokenDep = Annotated[TIToken, Depends(get_task_token)]
10 changes: 6 additions & 4 deletions airflow/api_fastapi/execution_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from fastapi import HTTPException, status

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.api_fastapi.execution_api import deps
from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection

Expand All @@ -44,7 +46,7 @@
def get_connection(
connection_id: str,
token: deps.TokenDep,
) -> datamodels.ConnectionResponse:
) -> ConnectionResponse:
"""Get an Airflow connection."""
if not has_connection_access(connection_id, token):
raise HTTPException(
Expand All @@ -64,10 +66,10 @@ def get_connection(
"message": f"Connection with ID {connection_id} not found",
},
)
return datamodels.ConnectionResponse.model_validate(connection, from_attributes=True)
return ConnectionResponse.model_validate(connection, from_attributes=True)


def has_connection_access(connection_id: str, token: datamodels.TIToken) -> bool:
def has_connection_access(connection_id: str, token: TIToken) -> bool:
"""Check if the task has access to the connection."""
# TODO: Placeholder for actual implementation

Expand Down
15 changes: 10 additions & 5 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TIEnterRunningPayload,
TIHeartbeatInfo,
TIStateUpdate,
TITerminalStatePayload,
)
from airflow.models.taskinstance import TaskInstance as TI
from airflow.utils import timezone
from airflow.utils.state import State
Expand All @@ -55,7 +60,7 @@
)
def ti_update_state(
task_instance_id: UUID,
ti_patch_payload: Annotated[datamodels.TIStateUpdate, Body()],
ti_patch_payload: Annotated[TIStateUpdate, Body()],
session: Annotated[Session, Depends(get_session)],
):
"""
Expand Down Expand Up @@ -85,7 +90,7 @@ def ti_update_state(

query = update(TI).where(TI.id == ti_id_str).values(data)

if isinstance(ti_patch_payload, datamodels.TIEnterRunningPayload):
if isinstance(ti_patch_payload, TIEnterRunningPayload):
if previous_state != State.QUEUED:
log.warning(
"Can not start Task Instance ('%s') in invalid state: %s",
Expand Down Expand Up @@ -115,7 +120,7 @@ def ti_update_state(
pid=ti_patch_payload.pid,
state=State.RUNNING,
)
elif isinstance(ti_patch_payload, datamodels.TITerminalStatePayload):
elif isinstance(ti_patch_payload, TITerminalStatePayload):
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)

# TODO: Replace this with FastAPI's Custom Exception handling:
Expand Down Expand Up @@ -143,7 +148,7 @@ def ti_update_state(
)
def ti_heartbeat(
task_instance_id: UUID,
ti_payload: datamodels.TIHeartbeatInfo,
ti_payload: TIHeartbeatInfo,
session: Annotated[Session, Depends(get_session)],
):
"""Update the heartbeat of a TaskInstance to mark it as alive & still running."""
Expand Down
13 changes: 6 additions & 7 deletions airflow/api_fastapi/execution_api/routes/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from fastapi import HTTPException, status

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.api_fastapi.execution_api import deps
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
Expand All @@ -40,10 +42,7 @@
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(
variable_key: str,
token: deps.TokenDep,
) -> datamodels.VariableResponse:
def get_variable(variable_key: str, token: deps.TokenDep) -> VariableResponse:
"""Get an Airflow Variable."""
if not has_variable_access(variable_key, token):
raise HTTPException(
Expand All @@ -65,10 +64,10 @@ def get_variable(
},
)

return datamodels.VariableResponse(key=variable_key, value=variable_value)
return VariableResponse(key=variable_key, value=variable_value)


def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool:
def has_variable_access(variable_key: str, token: TIToken) -> bool:
"""Check if the task has access to the variable."""
# TODO: Placeholder for actual implementation

Expand Down
10 changes: 6 additions & 4 deletions airflow/api_fastapi/execution_api/routes/xcoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels, deps
from airflow.api_fastapi.execution_api import deps
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
from airflow.models.xcom import BaseXCom

# TODO: Add dependency on JWT token
Expand All @@ -52,7 +54,7 @@ def get_xcom(
token: deps.TokenDep,
session: Annotated[Session, Depends(get_session)],
map_index: Annotated[int, Query()] = -1,
) -> datamodels.XComResponse:
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
if not has_xcom_access(key, token):
raise HTTPException(
Expand Down Expand Up @@ -100,10 +102,10 @@ def get_xcom(
},
)

return datamodels.XComResponse(key=key, value=xcom_value)
return XComResponse(key=key, value=xcom_value)


def has_xcom_access(xcom_key: str, token: datamodels.TIToken) -> bool:
def has_xcom_access(xcom_key: str, token: TIToken) -> bool:
"""Check if the task has access to the XCom."""
# TODO: Placeholder for actual implementation

Expand Down

0 comments on commit 5c442d3

Please sign in to comment.