diff --git a/airflow/api_fastapi/core_api/base.py b/airflow/api_fastapi/core_api/base.py new file mode 100644 index 0000000000000..52df0e6fea5de --- /dev/null +++ b/airflow/api_fastapi/core_api/base.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class BaseModel(PydanticBaseModel): + """ + Base pydantic model for REST API. + + :meta private: + """ + + model_config = ConfigDict(from_attributes=True) diff --git a/airflow/api_fastapi/core_api/datamodels/assets.py b/airflow/api_fastapi/core_api/datamodels/assets.py index 94ec17ad63d9d..adc32c2e4808f 100644 --- a/airflow/api_fastapi/core_api/datamodels/assets.py +++ b/airflow/api_fastapi/core_api/datamodels/assets.py @@ -19,8 +19,9 @@ from datetime import datetime -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.log.secrets_masker import redact diff --git a/airflow/api_fastapi/core_api/datamodels/backfills.py b/airflow/api_fastapi/core_api/datamodels/backfills.py index 69d6a98ccfd1a..be04063907a9d 100644 --- a/airflow/api_fastapi/core_api/datamodels/backfills.py +++ b/airflow/api_fastapi/core_api/datamodels/backfills.py @@ -19,8 +19,7 @@ from datetime import datetime -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.backfill import ReprocessBehavior diff --git a/airflow/api_fastapi/core_api/datamodels/config.py b/airflow/api_fastapi/core_api/datamodels/config.py index 0627832e45f4c..c16aa98093fb1 100644 --- a/airflow/api_fastapi/core_api/datamodels/config.py +++ b/airflow/api_fastapi/core_api/datamodels/config.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class ConfigOption(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow/api_fastapi/core_api/datamodels/connections.py index 7b23682cc8efa..d74ced1ba4d33 100644 --- a/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow/api_fastapi/core_api/datamodels/connections.py @@ -19,9 +19,10 @@ import json -from pydantic import BaseModel, Field, field_validator +from pydantic import Field, field_validator from pydantic_core.core_schema import ValidationInfo +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.log.secrets_masker import redact diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py index f3343e6c407d1..d211b0205b3b4 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -20,8 +20,9 @@ from datetime import datetime from enum import Enum -from pydantic import BaseModel, Field +from pydantic import Field +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType diff --git a/airflow/api_fastapi/core_api/datamodels/dag_sources.py b/airflow/api_fastapi/core_api/datamodels/dag_sources.py index 4b84bd1e6b1ce..6db3f334b805c 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_sources.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_sources.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class DAGSourceResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/dag_stats.py b/airflow/api_fastapi/core_api/datamodels/dag_stats.py index 0d768c2cbac07..1effdd5a94f7a 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_stats.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_stats.py @@ -17,8 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import DagRunState diff --git a/airflow/api_fastapi/core_api/datamodels/dag_warning.py b/airflow/api_fastapi/core_api/datamodels/dag_warning.py index f38a3a8d093f7..f1dbf14119160 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_warning.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_warning.py @@ -19,8 +19,7 @@ from datetime import datetime -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.dagwarning import DagWarningType diff --git a/airflow/api_fastapi/core_api/datamodels/dags.py b/airflow/api_fastapi/core_api/datamodels/dags.py index 27cc3ad473566..fc7bdcebe242b 100644 --- a/airflow/api_fastapi/core_api/datamodels/dags.py +++ b/airflow/api_fastapi/core_api/datamodels/dags.py @@ -25,12 +25,12 @@ from pendulum.tz.timezone import FixedTimezone, Timezone from pydantic import ( AliasGenerator, - BaseModel, ConfigDict, computed_field, field_validator, ) +from airflow.api_fastapi.core_api.base import BaseModel from airflow.configuration import conf from airflow.serialization.pydantic.dag import DagTagPydantic @@ -107,6 +107,17 @@ class DAGCollectionResponse(BaseModel): class DAGDetailsResponse(DAGResponse): """Specific serializer for DAG Details responses.""" + model_config = ConfigDict( + from_attributes=True, + alias_generator=AliasGenerator( + validation_alias=lambda field_name: { + "dag_run_timeout": "dagrun_timeout", + "last_parsed": "last_loaded", + "template_search_path": "template_searchpath", + }.get(field_name, field_name), + ), + ) + catchup: bool dag_run_timeout: timedelta | None asset_expression: dict | None @@ -120,16 +131,6 @@ class DAGDetailsResponse(DAGResponse): timezone: str | None last_parsed: datetime | None - model_config = ConfigDict( - alias_generator=AliasGenerator( - validation_alias=lambda field_name: { - "dag_run_timeout": "dagrun_timeout", - "last_parsed": "last_loaded", - "template_search_path": "template_searchpath", - }.get(field_name, field_name), - ) - ) - @field_validator("timezone", mode="before") @classmethod def get_timezone(cls, tz: Timezone | FixedTimezone) -> str | None: @@ -144,7 +145,7 @@ def get_params(cls, params: abc.MutableMapping | None) -> dict | None: """Convert params attribute to dict representation.""" if params is None: return None - return {param_name: param_val.dump() for param_name, param_val in params.items()} + return {k: v.dump() for k, v in params.items()} # Mypy issue https://github.com/python/mypy/issues/1362 @computed_field # type: ignore[misc] diff --git a/airflow/api_fastapi/core_api/datamodels/event_logs.py b/airflow/api_fastapi/core_api/datamodels/event_logs.py index 5b65ec85ba7b2..8ea88f363e947 100644 --- a/airflow/api_fastapi/core_api/datamodels/event_logs.py +++ b/airflow/api_fastapi/core_api/datamodels/event_logs.py @@ -19,12 +19,16 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel class EventLogResponse(BaseModel): """Event Log Response.""" + model_config = ConfigDict(populate_by_name=True, from_attributes=True) + id: int = Field(alias="event_log_id") dttm: datetime = Field(alias="when") dag_id: str | None @@ -37,8 +41,6 @@ class EventLogResponse(BaseModel): owner: str | None extra: str | None - model_config = ConfigDict(populate_by_name=True) - class EventLogCollectionResponse(BaseModel): """Event Log Collection Response.""" diff --git a/airflow/api_fastapi/core_api/datamodels/import_error.py b/airflow/api_fastapi/core_api/datamodels/import_error.py index ebc65e23eccbe..32c139da1a93a 100644 --- a/airflow/api_fastapi/core_api/datamodels/import_error.py +++ b/airflow/api_fastapi/core_api/datamodels/import_error.py @@ -18,19 +18,21 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict, Field +from pydantic import ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel class ImportErrorResponse(BaseModel): """Import Error Response.""" + model_config = ConfigDict(populate_by_name=True, from_attributes=True) + id: int = Field(alias="import_error_id") timestamp: datetime filename: str stacktrace: str = Field(alias="stack_trace") - model_config = ConfigDict(populate_by_name=True) - class ImportErrorCollectionResponse(BaseModel): """Import Error Collection Response.""" diff --git a/airflow/api_fastapi/core_api/datamodels/job.py b/airflow/api_fastapi/core_api/datamodels/job.py index e4d5ceb4b4e20..9fb4a61f9dd16 100644 --- a/airflow/api_fastapi/core_api/datamodels/job.py +++ b/airflow/api_fastapi/core_api/datamodels/job.py @@ -18,14 +18,12 @@ from datetime import datetime -from pydantic import BaseModel, ConfigDict +from airflow.api_fastapi.core_api.base import BaseModel class JobResponse(BaseModel): """Job serializer for responses.""" - model_config = ConfigDict(populate_by_name=True) - id: int dag_id: str | None state: str | None diff --git a/airflow/api_fastapi/core_api/datamodels/monitor.py b/airflow/api_fastapi/core_api/datamodels/monitor.py index 0734321a45fd5..fbaf40b4e8416 100644 --- a/airflow/api_fastapi/core_api/datamodels/monitor.py +++ b/airflow/api_fastapi/core_api/datamodels/monitor.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class BaseInfoSchema(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/plugins.py b/airflow/api_fastapi/core_api/datamodels/plugins.py index cc305ed3aa887..798ba6fa85d34 100644 --- a/airflow/api_fastapi/core_api/datamodels/plugins.py +++ b/airflow/api_fastapi/core_api/datamodels/plugins.py @@ -19,8 +19,9 @@ from typing import Annotated, Any -from pydantic import BaseModel, BeforeValidator, ConfigDict, field_validator +from pydantic import BeforeValidator, ConfigDict, field_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.plugins_manager import AirflowPluginSource diff --git a/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow/api_fastapi/core_api/datamodels/pools.py index 137392094cb5d..807627c7fefe7 100644 --- a/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow/api_fastapi/core_api/datamodels/pools.py @@ -19,7 +19,9 @@ from typing import Annotated, Callable -from pydantic import BaseModel, BeforeValidator, ConfigDict, Field +from pydantic import BeforeValidator, ConfigDict, Field + +from airflow.api_fastapi.core_api.base import BaseModel def _call_function(function: Callable[[], int]) -> int: @@ -61,7 +63,7 @@ class PoolCollectionResponse(BaseModel): class PoolPatchBody(BaseModel): """Pool serializer for patch bodies.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) name: str | None = Field(default=None, alias="pool") slots: int | None = None diff --git a/airflow/api_fastapi/core_api/datamodels/providers.py b/airflow/api_fastapi/core_api/datamodels/providers.py index 4e542f19f9f8e..8b515fafd2da7 100644 --- a/airflow/api_fastapi/core_api/datamodels/providers.py +++ b/airflow/api_fastapi/core_api/datamodels/providers.py @@ -17,7 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class ProviderResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py index 0e3e19862ea9a..6e2cc376dcd0d 100644 --- a/airflow/api_fastapi/core_api/datamodels/task_instances.py +++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py @@ -22,7 +22,6 @@ from pydantic import ( AliasPath, AwareDatetime, - BaseModel, BeforeValidator, ConfigDict, Field, @@ -31,6 +30,7 @@ model_validator, ) +from airflow.api_fastapi.core_api.base import BaseModel from airflow.api_fastapi.core_api.datamodels.job import JobResponse from airflow.api_fastapi.core_api.datamodels.trigger import TriggerResponse from airflow.utils.state import TaskInstanceState @@ -39,7 +39,7 @@ class TaskInstanceResponse(BaseModel): """TaskInstance serializer for responses.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) id: str task_id: str @@ -121,11 +121,14 @@ class TaskInstancesBatchBody(BaseModel): class TaskInstanceHistoryResponse(BaseModel): """TaskInstanceHistory serializer for responses.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) task_id: str dag_id: str + + # todo: this should not be aliased; it's ambiguous with dag run's "id" - airflow 3.0 run_id: str = Field(alias="dag_run_id") + map_index: int start_date: datetime | None end_date: datetime | None diff --git a/airflow/api_fastapi/core_api/datamodels/tasks.py b/airflow/api_fastapi/core_api/datamodels/tasks.py index 9b962390cc342..0806d4453c49a 100644 --- a/airflow/api_fastapi/core_api/datamodels/tasks.py +++ b/airflow/api_fastapi/core_api/datamodels/tasks.py @@ -22,9 +22,10 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, computed_field, field_validator, model_validator +from pydantic import computed_field, field_validator, model_validator from airflow.api_fastapi.common.types import TimeDeltaWithValidation +from airflow.api_fastapi.core_api.base import BaseModel from airflow.models.mappedoperator import MappedOperator from airflow.serialization.serialized_objects import SerializedBaseOperator, encode_priority_weight_strategy from airflow.task.priority_strategy import PriorityWeightStrategy diff --git a/airflow/api_fastapi/core_api/datamodels/trigger.py b/airflow/api_fastapi/core_api/datamodels/trigger.py index eb9be97d31407..265d40ff19bfd 100644 --- a/airflow/api_fastapi/core_api/datamodels/trigger.py +++ b/airflow/api_fastapi/core_api/datamodels/trigger.py @@ -19,7 +19,9 @@ from datetime import datetime from typing import Annotated -from pydantic import BaseModel, BeforeValidator, ConfigDict +from pydantic import BeforeValidator, ConfigDict + +from airflow.api_fastapi.core_api.base import BaseModel class TriggerResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/ui/dags.py b/airflow/api_fastapi/core_api/datamodels/ui/dags.py index 8c7af4dbf4688..991f5096b3d52 100644 --- a/airflow/api_fastapi/core_api/datamodels/ui/dags.py +++ b/airflow/api_fastapi/core_api/datamodels/ui/dags.py @@ -17,8 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel - +from airflow.api_fastapi.core_api.base import BaseModel from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse from airflow.api_fastapi.core_api.datamodels.dags import DAGResponse diff --git a/airflow/api_fastapi/core_api/datamodels/ui/dashboard.py b/airflow/api_fastapi/core_api/datamodels/ui/dashboard.py index ca0b5d98986a0..ad80685882829 100644 --- a/airflow/api_fastapi/core_api/datamodels/ui/dashboard.py +++ b/airflow/api_fastapi/core_api/datamodels/ui/dashboard.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class DAGRunTypes(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow/api_fastapi/core_api/datamodels/variables.py index 9a2ce996d3a42..624500b1c775d 100644 --- a/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow/api_fastapi/core_api/datamodels/variables.py @@ -19,8 +19,9 @@ import json -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import ConfigDict, Field, model_validator +from airflow.api_fastapi.core_api.base import BaseModel from airflow.typing_compat import Self from airflow.utils.log.secrets_masker import redact @@ -28,7 +29,7 @@ class VariableResponse(BaseModel): """Variable serializer for responses.""" - model_config = ConfigDict(populate_by_name=True) + model_config = ConfigDict(populate_by_name=True, from_attributes=True) key: str val: str | None = Field(alias="value") diff --git a/airflow/api_fastapi/core_api/datamodels/version.py b/airflow/api_fastapi/core_api/datamodels/version.py index 01c4c45376f70..b29864776c6fb 100644 --- a/airflow/api_fastapi/core_api/datamodels/version.py +++ b/airflow/api_fastapi/core_api/datamodels/version.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class VersionInfo(BaseModel): diff --git a/airflow/api_fastapi/core_api/datamodels/xcom.py b/airflow/api_fastapi/core_api/datamodels/xcom.py index 186b5aad77f09..370aa651cb2c8 100644 --- a/airflow/api_fastapi/core_api/datamodels/xcom.py +++ b/airflow/api_fastapi/core_api/datamodels/xcom.py @@ -19,7 +19,9 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, field_validator +from pydantic import field_validator + +from airflow.api_fastapi.core_api.base import BaseModel class XComResponse(BaseModel): diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 64a5acf826041..5aa37c7a6f9f8 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -109,7 +109,7 @@ def get_assets( ) ) return AssetCollectionResponse( - assets=[AssetResponse.model_validate(asset, from_attributes=True) for asset in assets], + assets=assets, total_entries=total_entries, ) @@ -157,9 +157,7 @@ def get_asset_events( assets_events = session.scalars(assets_event_select) return AssetEventCollectionResponse( - asset_events=[ - AssetEventResponse.model_validate(asset, from_attributes=True) for asset in assets_events - ], + asset_events=assets_events, total_entries=total_entries, ) @@ -187,7 +185,7 @@ def create_asset_event( if not assets_event: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Asset with uri: `{body.uri}` was not found") - return AssetEventResponse.model_validate(assets_event, from_attributes=True) + return assets_event @assets_router.get( @@ -247,7 +245,7 @@ def get_asset( if asset is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Asset with uri: `{uri}` was not found") - return AssetResponse.model_validate(asset, from_attributes=True) + return AssetResponse.model_validate(asset) @assets_router.get( @@ -282,10 +280,7 @@ def get_dag_asset_queued_events( ] return QueuedEventCollectionResponse( - queued_events=[ - QueuedEventResponse.model_validate(queued_event, from_attributes=True) - for queued_event in queued_events - ], + queued_events=queued_events, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 9c5dd0895c47a..aa6f540d32791 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -70,7 +70,7 @@ def list_backfills( backfills = session.scalars(select_stmt) return BackfillCollectionResponse( - backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills], + backfills=backfills, total_entries=total_entries, ) @@ -85,7 +85,7 @@ def get_backfill( ) -> BackfillResponse: backfill = session.get(Backfill, backfill_id) if backfill: - return BackfillResponse.model_validate(backfill, from_attributes=True) + return backfill raise HTTPException(status.HTTP_404_NOT_FOUND, "Backfill not found") @@ -107,7 +107,7 @@ def pause_backfill(backfill_id, session: Annotated[Session, Depends(get_session) if b.is_paused is False: b.is_paused = True session.commit() - return BackfillResponse.model_validate(b, from_attributes=True) + return b @backfills_router.put( @@ -127,7 +127,7 @@ def unpause_backfill(backfill_id, session: Annotated[Session, Depends(get_sessio raise HTTPException(status.HTTP_409_CONFLICT, "Backfill is already completed.") if b.is_paused: b.is_paused = False - return BackfillResponse.model_validate(b, from_attributes=True) + return b @backfills_router.put( @@ -172,7 +172,7 @@ def cancel_backfill(backfill_id, session: Annotated[Session, Depends(get_session # this is in separate transaction just to avoid potential conflicts session.refresh(b) b.completed_at = timezone.utcnow() - return BackfillResponse.model_validate(b, from_attributes=True) + return b @backfills_router.post( @@ -199,7 +199,7 @@ def create_backfill( dag_run_conf=backfill_request.dag_run_conf, reprocess_behavior=backfill_request.reprocess_behavior, ) - return BackfillResponse.model_validate(backfill_obj, from_attributes=True) + return BackfillResponse.model_validate(backfill_obj) except AlreadyRunningBackfill: raise HTTPException( status_code=status.HTTP_409_CONFLICT, diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index 0716c77f9f4e8..46ebcfcf98ca1 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -78,7 +78,7 @@ def get_connection( status.HTTP_404_NOT_FOUND, f"The Connection with connection_id: `{connection_id}` was not found" ) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return connection @connections_router.get( @@ -110,9 +110,7 @@ def get_connections( connections = session.scalars(connection_select) return ConnectionCollectionResponse( - connections=[ - ConnectionResponse.model_validate(connection, from_attributes=True) for connection in connections - ], + connections=connections, total_entries=total_entries, ) @@ -142,7 +140,7 @@ def post_connection( connection = Connection(**post_body.model_dump(by_alias=True)) session.add(connection) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return connection @connections_router.patch( @@ -182,7 +180,7 @@ def patch_connection( for key, val in data.items(): setattr(connection, key, val) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return connection @connections_router.post( @@ -213,8 +211,6 @@ def test_connection( conn = Connection(**data) os.environ[conn_env_var] = conn.get_uri() test_status, test_message = conn.test_connection() - return ConnectionTestResponse.model_validate( - {"status": test_status, "message": test_message}, from_attributes=True - ) + return ConnectionTestResponse.model_validate({"status": test_status, "message": test_message}) finally: os.environ.pop(conn_env_var, None) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index d7a196eba3f35..c26650767c98a 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Annotated +from typing import Annotated, cast from fastapi import Depends, HTTPException, Query, Request, status from sqlalchemy import select @@ -38,7 +38,7 @@ datetime_range_filter_factory, ) from airflow.api_fastapi.common.router import AirflowRouter -from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse, AssetEventResponse +from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse from airflow.api_fastapi.core_api.datamodels.dag_run import ( DAGRunClearBody, DAGRunCollectionResponse, @@ -74,7 +74,7 @@ def get_dag_run( f"The DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` was not found", ) - return DAGRunResponse.model_validate(dag_run, from_attributes=True) + return dag_run @dag_run_router.delete( @@ -156,7 +156,7 @@ def patch_dag_run( dag_run = session.get(DagRun, dag_run.id) - return DAGRunResponse.model_validate(dag_run, from_attributes=True) + return dag_run @dag_run_router.get( @@ -184,9 +184,7 @@ def get_upstream_asset_events( ) events = dag_run.consumed_asset_events return AssetEventCollectionResponse( - asset_events=[ - AssetEventResponse.model_validate(asset_event, from_attributes=True) for asset_event in events - ], + asset_events=events, total_entries=len(events), ) @@ -223,9 +221,7 @@ def clear_dag_run( ) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(ti, from_attributes=True) for ti in task_instances - ], + task_instances=cast(list[TaskInstanceResponse], task_instances), total_entries=len(task_instances), ) else: @@ -237,7 +233,7 @@ def clear_dag_run( session=session, ) dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id)) - return DAGRunResponse.model_validate(dag_run_cleared, from_attributes=True) + return dag_run_cleared @dag_run_router.get("", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) @@ -297,6 +293,6 @@ def get_dag_runs( ) dag_runs = session.scalars(dag_run_select) return DAGRunCollectionResponse( - dag_runs=[DAGRunResponse.model_validate(dr, from_attributes=True) for dr in dag_runs], + dag_runs=dag_runs, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index 9ddd1439b199a..e933710bc6903 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -37,7 +37,6 @@ from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.dag_warning import ( DAGWarningCollectionResponse, - DAGWarningResponse, ) from airflow.models import DagWarning @@ -70,6 +69,6 @@ def list_dag_warnings( dag_warnings = session.scalars(dag_warnings_select) return DAGWarningCollectionResponse( - dag_warnings=[DAGWarningResponse.model_validate(w, from_attributes=True) for w in dag_warnings], + dag_warnings=dag_warnings, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 1416855b11c94..99a86508edad4 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -101,7 +101,7 @@ def get_dags( dags = session.scalars(dags_select) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + dags=dags, total_entries=total_entries, ) @@ -162,7 +162,7 @@ def get_dag(dag_id: str, session: Annotated[Session, Depends(get_session)], requ if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) - return DAGResponse.model_validate(dag_model, from_attributes=True) + return dag_model @dags_router.get( @@ -190,7 +190,7 @@ def get_dag_details( if not key.startswith("_") and not hasattr(dag_model, key): setattr(dag_model, key, value) - return DAGDetailsResponse.model_validate(dag_model, from_attributes=True) + return dag_model @dags_router.patch( @@ -227,7 +227,7 @@ def patch_dag( for key, val in data.items(): setattr(dag, key, val) - return DAGResponse.model_validate(dag, from_attributes=True) + return dag @dags_router.patch( @@ -280,7 +280,7 @@ def patch_dags( ) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(d, from_attributes=True) for d in dags], + dags=dags, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow/api_fastapi/core_api/routes/public/event_logs.py index 7d2933365a956..51feb7e22cfb2 100644 --- a/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -54,10 +54,7 @@ def get_event_log( event_log = session.scalar(select(Log).where(Log.id == event_log_id)) if event_log is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Event Log with id: `{event_log_id}` not found") - return EventLogResponse.model_validate( - event_log, - from_attributes=True, - ) + return event_log @event_logs_router.get( @@ -134,6 +131,6 @@ def get_event_logs( event_logs = session.scalars(event_logs_select) return EventLogCollectionResponse( - event_logs=[EventLogResponse.model_validate(e, from_attributes=True) for e in event_logs], + event_logs=event_logs, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index e17abfbe1ffc9..233f94df3102d 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -58,10 +58,7 @@ def get_import_error( f"The ImportError with import_error_id: `{import_error_id}` was not found", ) - return ImportErrorResponse.model_validate( - error, - from_attributes=True, - ) + return error @import_error_router.get( @@ -98,6 +95,6 @@ def get_import_errors( import_errors = session.scalars(import_errors_select) return ImportErrorCollectionResponse( - import_errors=[ImportErrorResponse.model_validate(i, from_attributes=True) for i in import_errors], + import_errors=import_errors, total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/plugins.py b/airflow/api_fastapi/core_api/routes/public/plugins.py index 717d073636c19..61268d6c5f6ec 100644 --- a/airflow/api_fastapi/core_api/routes/public/plugins.py +++ b/airflow/api_fastapi/core_api/routes/public/plugins.py @@ -17,6 +17,8 @@ from __future__ import annotations +from typing import cast + from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.plugins import PluginCollectionResponse, PluginResponse @@ -32,9 +34,6 @@ def get_plugins( ) -> PluginCollectionResponse: plugins_info = sorted(get_plugin_info(), key=lambda x: x["name"]) return PluginCollectionResponse( - plugins=[ - PluginResponse.model_validate(plugin_info) - for plugin_info in plugins_info[offset.value :][: limit.value] - ], + plugins=cast(list[PluginResponse], plugins_info[offset.value :][: limit.value]), total_entries=len(plugins_info), ) diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index 0e67994acfaab..6fe1cb3a312b3 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Annotated +from typing import Annotated, cast from fastapi import Depends, HTTPException, Query, status from fastapi.exceptions import RequestValidationError @@ -78,7 +78,7 @@ def get_pool( if pool is None: raise HTTPException(status.HTTP_404_NOT_FOUND, f"The Pool with name: `{pool_name}` was not found") - return PoolResponse.model_validate(pool, from_attributes=True) + return pool @pools_router.get( @@ -106,7 +106,7 @@ def get_pools( pools = session.scalars(pools_select) return PoolCollectionResponse( - pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], + pools=pools, total_entries=total_entries, ) @@ -155,7 +155,7 @@ def patch_pool( for key, value in data.items(): setattr(pool, key, value) - return PoolResponse.model_validate(pool, from_attributes=True) + return pool @pools_router.post( @@ -172,8 +172,7 @@ def post_pool( """Create a Pool.""" pool = Pool(**body.model_dump()) session.add(pool) - - return PoolResponse.model_validate(pool, from_attributes=True) + return pool @pools_router.post( @@ -193,6 +192,6 @@ def post_pools( pools = [Pool(**body.model_dump()) for body in body.pools] session.add_all(pools) return PoolCollectionResponse( - pools=[PoolResponse.model_validate(pool, from_attributes=True) for pool in pools], + pools=cast(list[PoolResponse], pools), total_entries=len(pools), ) diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index ae80fe779b2c6..857b03ab00e6e 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -98,7 +98,7 @@ def get_task_instance( status.HTTP_404_NOT_FOUND, "Task instance is mapped, add the map_index value to the URL" ) - return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + return task_instance @task_instances_router.get( @@ -175,10 +175,7 @@ def get_mapped_task_instances( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(task_instance, from_attributes=True) - for task_instance in task_instances - ], + task_instances=task_instances, total_entries=total_entries, ) @@ -263,7 +260,7 @@ def get_mapped_task_instance( f"The Mapped Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, and map_index: `{map_index}` was not found", ) - return TaskInstanceResponse.model_validate(task_instance, from_attributes=True) + return task_instance @task_instances_router.get( @@ -339,7 +336,7 @@ def get_task_instances( ) task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], + task_instances=task_instances, total_entries=total_entries, ) @@ -415,7 +412,7 @@ def get_task_instances_batch( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], + task_instances=task_instances, total_entries=total_entries, ) @@ -452,7 +449,7 @@ def _query(orm_object: Base) -> TI | TIH | None: status.HTTP_404_NOT_FOUND, f"The Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, try_number: `{task_try_number}` and map_index: `{map_index}` was not found", ) - return TaskInstanceHistoryResponse.model_validate(result, from_attributes=True) + return result @task_instances_router.get( diff --git a/airflow/api_fastapi/core_api/routes/public/tasks.py b/airflow/api_fastapi/core_api/routes/public/tasks.py index be1fdc7324d8f..972fc8cdc9046 100644 --- a/airflow/api_fastapi/core_api/routes/public/tasks.py +++ b/airflow/api_fastapi/core_api/routes/public/tasks.py @@ -18,6 +18,7 @@ from __future__ import annotations from operator import attrgetter +from typing import cast from fastapi import HTTPException, Request, status @@ -53,8 +54,8 @@ def get_tasks( except AttributeError as err: raise HTTPException(status.HTTP_400_BAD_REQUEST, str(err)) return TaskCollectionResponse( - tasks=[TaskResponse.model_validate(task, from_attributes=True) for task in tasks], - total_entries=(len(tasks)), + tasks=cast(list[TaskResponse], tasks), + total_entries=len(tasks), ) @@ -76,4 +77,4 @@ def get_task(dag_id: str, task_id, request: Request) -> TaskResponse: task = dag.get_task(task_id=task_id) except TaskNotFound: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Task with id {task_id} was not found") - return TaskResponse.model_validate(task, from_attributes=True) + return cast(TaskResponse, task) diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index 541dbcb8f107a..a96aa51b5dd64 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -68,7 +68,7 @@ def get_variable( status.HTTP_404_NOT_FOUND, f"The Variable with key: `{variable_key}` was not found" ) - return VariableResponse.model_validate(variable, from_attributes=True) + return variable @variables_router.get( @@ -100,7 +100,7 @@ def get_variables( variables = session.scalars(variable_select) return VariableCollectionResponse( - variables=[VariableResponse.model_validate(variable, from_attributes=True) for variable in variables], + variables=variables, total_entries=total_entries, ) @@ -139,7 +139,7 @@ def patch_variable( data = patch_body.model_dump(exclude=non_update_fields, by_alias=True, exclude_none=True) for key, val in data.items(): setattr(variable, key, val) - return VariableResponse.model_validate(variable, from_attributes=True) + return variable @variables_router.post( @@ -155,4 +155,4 @@ def post_variable( variable = session.scalar(select(Variable).where(Variable.key == post_body.key).limit(1)) - return VariableResponse.model_validate(variable, from_attributes=True) + return variable diff --git a/airflow/api_fastapi/core_api/routes/public/version.py b/airflow/api_fastapi/core_api/routes/public/version.py index 0e784fbfd8694..9444fd5a774e8 100644 --- a/airflow/api_fastapi/core_api/routes/public/version.py +++ b/airflow/api_fastapi/core_api/routes/public/version.py @@ -31,4 +31,4 @@ def get_version() -> VersionInfo: airflow_version = airflow.__version__ git_version = get_airflow_git_version() version_info = VersionInfo(version=airflow_version, git_version=git_version) - return VersionInfo.model_validate(version_info) + return version_info diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py index ef13c927e8636..dff2933940c62 100644 --- a/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -89,6 +89,6 @@ def get_xcom_entry( item = xcom_stub if stringify: - return XComResponseString.model_validate(item, from_attributes=True) + return XComResponseString.model_validate(item) - return XComResponseNative.model_validate(item, from_attributes=True) + return XComResponseNative.model_validate(item) diff --git a/airflow/api_fastapi/core_api/routes/ui/dags.py b/airflow/api_fastapi/core_api/routes/ui/dags.py index 96b8b0c1b109f..017ef3c165701 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dags.py +++ b/airflow/api_fastapi/core_api/routes/ui/dags.py @@ -124,9 +124,9 @@ def recent_dag_runs( for row in dags_with_recent_dag_runs: dag_run, dag, *_ = row dag_id = dag.dag_id - dag_run_response = DAGRunResponse.model_validate(dag_run, from_attributes=True) + dag_run_response = DAGRunResponse.model_validate(dag_run) if dag_id not in dag_runs_by_dag_id: - dag_response = DAGResponse.model_validate(dag, from_attributes=True) + dag_response = DAGResponse.model_validate(dag) dag_runs_by_dag_id[dag_id] = DAGWithLatestDagRunsResponse.model_validate( { **dag_response.dict(), diff --git a/airflow/api_fastapi/core_api/routes/ui/dashboard.py b/airflow/api_fastapi/core_api/routes/ui/dashboard.py index 9462d7ee2f7c0..24682fa0c17c9 100644 --- a/airflow/api_fastapi/core_api/routes/ui/dashboard.py +++ b/airflow/api_fastapi/core_api/routes/ui/dashboard.py @@ -97,4 +97,4 @@ def historical_metrics( }, } - return HistoricalMetricDataResponse.model_validate(historical_metrics_response, from_attributes=True) + return HistoricalMetricDataResponse.model_validate(historical_metrics_response) diff --git a/airflow/api_fastapi/execution_api/datamodels/connection.py b/airflow/api_fastapi/execution_api/datamodels/connection.py index f3c678952982e..e2641417f5669 100644 --- a/airflow/api_fastapi/execution_api/datamodels/connection.py +++ b/airflow/api_fastapi/execution_api/datamodels/connection.py @@ -17,7 +17,9 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel class ConnectionResponse(BaseModel): diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 07066eb5a5cc3..a2be682cd60d9 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -20,17 +20,16 @@ import uuid from typing import Annotated, Literal, Union -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, WithJsonSchema +from pydantic import Discriminator, Tag, WithJsonSchema from airflow.api_fastapi.common.types import UtcDateTime +from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import IntermediateTIState, TaskInstanceState as TIState, TerminalTIState class TIEnterRunningPayload(BaseModel): """Schema for updating TaskInstance to 'RUNNING' state with minimal required fields.""" - model_config = ConfigDict(from_attributes=True) - state: Annotated[ Literal[TIState.RUNNING], # Specify a default in the schema, but not in code, so Pydantic marks it as required. diff --git a/airflow/api_fastapi/execution_api/datamodels/token.py b/airflow/api_fastapi/execution_api/datamodels/token.py index 7086c39813e33..568fdcf592a8a 100644 --- a/airflow/api_fastapi/execution_api/datamodels/token.py +++ b/airflow/api_fastapi/execution_api/datamodels/token.py @@ -17,7 +17,7 @@ from __future__ import annotations -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel # TODO: This is a placeholder for Task Identity Token schema. diff --git a/airflow/api_fastapi/execution_api/datamodels/variable.py b/airflow/api_fastapi/execution_api/datamodels/variable.py index 6819286f54bf6..548d593476671 100644 --- a/airflow/api_fastapi/execution_api/datamodels/variable.py +++ b/airflow/api_fastapi/execution_api/datamodels/variable.py @@ -17,13 +17,13 @@ from __future__ import annotations -from pydantic import BaseModel, ConfigDict, Field +from pydantic import Field + +from airflow.api_fastapi.core_api.base import BaseModel class VariableResponse(BaseModel): """Variable schema for responses with fields that are needed for Runtime.""" - model_config = ConfigDict(from_attributes=True) - key: str val: str | None = Field(alias="value") diff --git a/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow/api_fastapi/execution_api/datamodels/xcom.py index 6fb6c14629e37..1f913f9ac380e 100644 --- a/airflow/api_fastapi/execution_api/datamodels/xcom.py +++ b/airflow/api_fastapi/execution_api/datamodels/xcom.py @@ -19,7 +19,7 @@ from typing import Any -from pydantic import BaseModel +from airflow.api_fastapi.core_api.base import BaseModel class XComResponse(BaseModel): diff --git a/airflow/api_fastapi/execution_api/routes/connections.py b/airflow/api_fastapi/execution_api/routes/connections.py index 86f94f5ef3f8e..ed72522ee8e24 100644 --- a/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow/api_fastapi/execution_api/routes/connections.py @@ -66,7 +66,7 @@ def get_connection( "message": f"Connection with ID {connection_id} not found", }, ) - return ConnectionResponse.model_validate(connection, from_attributes=True) + return ConnectionResponse.model_validate(connection) def has_connection_access(connection_id: str, token: TIToken) -> bool: