From 5c442d378dc3d06c2a18cf7b7f6f2777e22da6f5 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 15 Nov 2024 19:08:39 +0000 Subject: [PATCH] Modularize datamodels in Execution API (#44068) Split the `datamodels` module to granular mods --- .../execution_api/datamodels/__init__.py | 16 +++++++ .../execution_api/datamodels/connection.py | 33 +++++++++++++ .../taskinstance.py} | 48 +------------------ .../execution_api/datamodels/token.py | 27 +++++++++++ .../execution_api/datamodels/variable.py | 29 +++++++++++ .../execution_api/datamodels/xcom.py | 30 ++++++++++++ airflow/api_fastapi/execution_api/deps.py | 8 ++-- .../execution_api/routes/connections.py | 10 ++-- .../execution_api/routes/task_instances.py | 15 ++++-- .../execution_api/routes/variables.py | 13 +++-- .../api_fastapi/execution_api/routes/xcoms.py | 10 ++-- 11 files changed, 169 insertions(+), 70 deletions(-) create mode 100644 airflow/api_fastapi/execution_api/datamodels/__init__.py create mode 100644 airflow/api_fastapi/execution_api/datamodels/connection.py rename airflow/api_fastapi/execution_api/{datamodels.py => datamodels/taskinstance.py} (76%) create mode 100644 airflow/api_fastapi/execution_api/datamodels/token.py create mode 100644 airflow/api_fastapi/execution_api/datamodels/variable.py create mode 100644 airflow/api_fastapi/execution_api/datamodels/xcom.py diff --git a/airflow/api_fastapi/execution_api/datamodels/__init__.py b/airflow/api_fastapi/execution_api/datamodels/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/api_fastapi/execution_api/datamodels/connection.py b/airflow/api_fastapi/execution_api/datamodels/connection.py new file mode 100644 index 0000000000000..f3c678952982e --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/connection.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class ConnectionResponse(BaseModel): + """Connection schema for responses with fields that are needed for Runtime.""" + + conn_id: str + conn_type: str + host: str | None + schema_: str | None = Field(alias="schema") + login: str | None + password: str | None + port: int | None + extra: str | None diff --git a/airflow/api_fastapi/execution_api/datamodels.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py similarity index 76% rename from airflow/api_fastapi/execution_api/datamodels.py rename to airflow/api_fastapi/execution_api/datamodels/taskinstance.py index ec8be531e103e..db63dc3a8dbb1 100644 --- a/airflow/api_fastapi/execution_api/datamodels.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -17,16 +17,9 @@ from __future__ import annotations -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Literal, Union -from pydantic import ( - BaseModel, - ConfigDict, - Discriminator, - Field, - Tag, - WithJsonSchema, -) +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema from airflow.api_fastapi.common.types import UtcDateTime from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState @@ -104,40 +97,3 @@ class TIHeartbeatInfo(BaseModel): hostname: str pid: int - - -class ConnectionResponse(BaseModel): - """Connection schema for responses with fields that are needed for Runtime.""" - - conn_id: str - conn_type: str - host: str | None - schema_: str | None = Field(alias="schema") - login: str | None - password: str | None - port: int | None - extra: str | None - - -class VariableResponse(BaseModel): - """Variable schema for responses with fields that are needed for Runtime.""" - - model_config = ConfigDict(from_attributes=True) - - key: str - val: str | None = Field(alias="value") - - -class XComResponse(BaseModel): - """XCom schema for responses with fields that are needed for Runtime.""" - - key: str - value: Any - """The returned XCom value in a JSON-compatible format.""" - - -# TODO: This is a placeholder for Task Identity Token schema. -class TIToken(BaseModel): - """Task Identity Token.""" - - ti_key: str diff --git a/airflow/api_fastapi/execution_api/datamodels/token.py b/airflow/api_fastapi/execution_api/datamodels/token.py new file mode 100644 index 0000000000000..7086c39813e33 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/token.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pydantic import BaseModel + + +# TODO: This is a placeholder for Task Identity Token schema. +class TIToken(BaseModel): + """Task Identity Token.""" + + ti_key: str diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py new file mode 100644 index 0000000000000..6819286f54bf6 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pydantic import BaseModel, ConfigDict, Field + + +class VariableResponse(BaseModel): + """Variable schema for responses with fields that are needed for Runtime.""" + + model_config = ConfigDict(from_attributes=True) + + key: str + val: str | None = Field(alias="value") diff --git a/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow/api_fastapi/execution_api/datamodels/xcom.py new file mode 100644 index 0000000000000..6fb6c14629e37 --- /dev/null +++ b/airflow/api_fastapi/execution_api/datamodels/xcom.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel + + +class XComResponse(BaseModel): + """XCom schema for responses with fields that are needed for Runtime.""" + + key: str + value: Any + """The returned XCom value in a JSON-compatible format.""" diff --git a/airflow/api_fastapi/execution_api/deps.py b/airflow/api_fastapi/execution_api/deps.py index 4564324668a3a..9e409bd3d6c94 100644 --- a/airflow/api_fastapi/execution_api/deps.py +++ b/airflow/api_fastapi/execution_api/deps.py @@ -21,12 +21,12 @@ from fastapi import Depends -from airflow.api_fastapi.execution_api import datamodels +from airflow.api_fastapi.execution_api.datamodels.token import TIToken -def get_task_token() -> datamodels.TIToken: +def get_task_token() -> TIToken: """TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation.""" - return datamodels.TIToken(ti_key="test_key") + return TIToken(ti_key="test_key") -TokenDep = Annotated[datamodels.TIToken, Depends(get_task_token)] +TokenDep = Annotated[TIToken, Depends(get_task_token)] diff --git a/airflow/api_fastapi/execution_api/routes/connections.py b/airflow/api_fastapi/execution_api/routes/connections.py index d31cfbaeb9d0e..86f94f5ef3f8e 100644 --- a/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow/api_fastapi/execution_api/routes/connections.py @@ -22,7 +22,9 @@ from fastapi import HTTPException, status from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels, deps +from airflow.api_fastapi.execution_api import deps +from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse +from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.exceptions import AirflowNotFoundException from airflow.models.connection import Connection @@ -44,7 +46,7 @@ def get_connection( connection_id: str, token: deps.TokenDep, -) -> datamodels.ConnectionResponse: +) -> ConnectionResponse: """Get an Airflow connection.""" if not has_connection_access(connection_id, token): raise HTTPException( @@ -64,10 +66,10 @@ def get_connection( "message": f"Connection with ID {connection_id} not found", }, ) - return datamodels.ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection, from_attributes=True) -def has_connection_access(connection_id: str, token: datamodels.TIToken) -> bool: +def has_connection_access(connection_id: str, token: TIToken) -> bool: """Check if the task has access to the connection.""" # TODO: Placeholder for actual implementation diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 97723ffc4a873..3adbd51ff2aae 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -29,7 +29,12 @@ from airflow.api_fastapi.common.db.common import get_session from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels +from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( + TIEnterRunningPayload, + TIHeartbeatInfo, + TIStateUpdate, + TITerminalStatePayload, +) from airflow.models.taskinstance import TaskInstance as TI from airflow.utils import timezone from airflow.utils.state import State @@ -55,7 +60,7 @@ ) def ti_update_state( task_instance_id: UUID, - ti_patch_payload: Annotated[datamodels.TIStateUpdate, Body()], + ti_patch_payload: Annotated[TIStateUpdate, Body()], session: Annotated[Session, Depends(get_session)], ): """ @@ -85,7 +90,7 @@ def ti_update_state( query = update(TI).where(TI.id == ti_id_str).values(data) - if isinstance(ti_patch_payload, datamodels.TIEnterRunningPayload): + if isinstance(ti_patch_payload, TIEnterRunningPayload): if previous_state != State.QUEUED: log.warning( "Can not start Task Instance ('%s') in invalid state: %s", @@ -115,7 +120,7 @@ def ti_update_state( pid=ti_patch_payload.pid, state=State.RUNNING, ) - elif isinstance(ti_patch_payload, datamodels.TITerminalStatePayload): + elif isinstance(ti_patch_payload, TITerminalStatePayload): query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind) # TODO: Replace this with FastAPI's Custom Exception handling: @@ -143,7 +148,7 @@ def ti_update_state( ) def ti_heartbeat( task_instance_id: UUID, - ti_payload: datamodels.TIHeartbeatInfo, + ti_payload: TIHeartbeatInfo, session: Annotated[Session, Depends(get_session)], ): """Update the heartbeat of a TaskInstance to mark it as alive & still running.""" diff --git a/airflow/api_fastapi/execution_api/routes/variables.py b/airflow/api_fastapi/execution_api/routes/variables.py index 1ecc7480ee368..e8e2012e8d1e1 100644 --- a/airflow/api_fastapi/execution_api/routes/variables.py +++ b/airflow/api_fastapi/execution_api/routes/variables.py @@ -22,7 +22,9 @@ from fastapi import HTTPException, status from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels, deps +from airflow.api_fastapi.execution_api import deps +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.datamodels.variable import VariableResponse from airflow.models.variable import Variable # TODO: Add dependency on JWT token @@ -40,10 +42,7 @@ status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"}, }, ) -def get_variable( - variable_key: str, - token: deps.TokenDep, -) -> datamodels.VariableResponse: +def get_variable(variable_key: str, token: deps.TokenDep) -> VariableResponse: """Get an Airflow Variable.""" if not has_variable_access(variable_key, token): raise HTTPException( @@ -65,10 +64,10 @@ def get_variable( }, ) - return datamodels.VariableResponse(key=variable_key, value=variable_value) + return VariableResponse(key=variable_key, value=variable_value) -def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool: +def has_variable_access(variable_key: str, token: TIToken) -> bool: """Check if the task has access to the variable.""" # TODO: Placeholder for actual implementation diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py index 0f0a04ed26941..083947923dcd9 100644 --- a/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -26,7 +26,9 @@ from airflow.api_fastapi.common.db.common import get_session from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.execution_api import datamodels, deps +from airflow.api_fastapi.execution_api import deps +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse from airflow.models.xcom import BaseXCom # TODO: Add dependency on JWT token @@ -52,7 +54,7 @@ def get_xcom( token: deps.TokenDep, session: Annotated[Session, Depends(get_session)], map_index: Annotated[int, Query()] = -1, -) -> datamodels.XComResponse: +) -> XComResponse: """Get an Airflow XCom from database - not other XCom Backends.""" if not has_xcom_access(key, token): raise HTTPException( @@ -100,10 +102,10 @@ def get_xcom( }, ) - return datamodels.XComResponse(key=key, value=xcom_value) + return XComResponse(key=key, value=xcom_value) -def has_xcom_access(xcom_key: str, token: datamodels.TIToken) -> bool: +def has_xcom_access(xcom_key: str, token: TIToken) -> bool: """Check if the task has access to the XCom.""" # TODO: Placeholder for actual implementation