From 22d1406af24c3741134ac4fb9f2f49e9a41d03eb Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:10:23 -0800 Subject: [PATCH] Make filters param optional and fix typing (#44226) Given that sometimes we don't want to apply any filters, it makes sense to make the param optional. I also fix the typing on `paginated_select`. --- airflow/api_fastapi/common/db/common.py | 56 +++++++++++++++---- .../core_api/routes/public/assets.py | 6 +- .../core_api/routes/public/backfills.py | 4 +- .../core_api/routes/public/connections.py | 1 - .../core_api/routes/public/dag_run.py | 3 +- .../core_api/routes/public/dag_warning.py | 6 +- .../core_api/routes/public/dags.py | 8 +-- .../core_api/routes/public/event_logs.py | 9 +-- .../core_api/routes/public/import_error.py | 5 +- .../core_api/routes/public/pools.py | 1 - .../core_api/routes/public/task_instances.py | 14 +---- .../core_api/routes/public/variables.py | 1 - 12 files changed, 59 insertions(+), 55 deletions(-) diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index e083cf650fd8d..17da1eafacc93 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -14,10 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +Database helpers for Airflow REST API. + +:meta private: +""" from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Literal, Sequence, overload from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, create_session, provide_session @@ -47,30 +52,59 @@ def your_route(session: Annotated[Session, Depends(get_session)]): yield session -def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select: - base_select = base_select - for filter in filters: - if filter is None: +def apply_filters_to_select( + *, base_select: Select, filters: Sequence[BaseParam | None] | None = None +) -> Select: + if filters is None: + return base_select + for f in filters: + if f is None: continue - base_select = filter.to_orm(base_select) + base_select = f.to_orm(base_select) return base_select +@overload +def paginated_select( + *, + select: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: Session = NEW_SESSION, + return_total_entries: Literal[True] = True, +) -> tuple[Select, int]: ... + + +@overload +def paginated_select( + *, + select: Select, + filters: Sequence[BaseParam] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: Session = NEW_SESSION, + return_total_entries: Literal[False], +) -> tuple[Select, None]: ... + + @provide_session def paginated_select( *, select: Select, - filters: Sequence[BaseParam], + filters: Sequence[BaseParam] | None = None, order_by: BaseParam | None = None, offset: BaseParam | None = None, limit: BaseParam | None = None, session: Session = NEW_SESSION, return_total_entries: bool = True, -) -> Select: +) -> tuple[Select, int | None]: base_select = apply_filters_to_select( - select, - filters, + base_select=select, + filters=filters, ) total_entries = None @@ -82,6 +116,6 @@ def paginated_select( # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) - base_select = apply_filters_to_select(base_select, [order_by, offset, limit]) + base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit]) return base_select, total_entries diff --git a/airflow/api_fastapi/core_api/routes/public/assets.py b/airflow/api_fastapi/core_api/routes/public/assets.py index 7da7fcb8e878d..64a5acf826041 100644 --- a/airflow/api_fastapi/core_api/routes/public/assets.py +++ b/airflow/api_fastapi/core_api/routes/public/assets.py @@ -102,6 +102,7 @@ def get_assets( limit=limit, session=session, ) + assets = session.scalars( assets_select.options( subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks) @@ -211,7 +212,7 @@ def get_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[]) + dag_asset_queued_events_select, total_entries = paginated_select(select=query) adrqs = session.execute(dag_asset_queued_events_select).all() if not adrqs: @@ -270,9 +271,8 @@ def get_dag_asset_queued_events( .where(*where_clause) ) - dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[]) + dag_asset_queued_events_select, total_entries = paginated_select(select=query) adrqs = session.execute(dag_asset_queued_events_select).all() - if not adrqs: raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with dag_id: `{dag_id}` was not found") diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index c4ab7ce16b603..9c5dd0895c47a 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -61,16 +61,16 @@ def list_backfills( ) -> BackfillCollectionResponse: select_stmt, total_entries = paginated_select( select=select(Backfill).where(Backfill.dag_id == dag_id), - filters=[], order_by=order_by, offset=offset, limit=limit, session=session, ) + backfills = session.scalars(select_stmt) return BackfillCollectionResponse( - backfills=[BackfillResponse.model_validate(x, from_attributes=True) for x in backfills], + backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow/api_fastapi/core_api/routes/public/connections.py index 1ca158bad5dac..0716c77f9f4e8 100644 --- a/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow/api_fastapi/core_api/routes/public/connections.py @@ -101,7 +101,6 @@ def get_connections( """Get all connection entries.""" connection_select, total_entries = paginated_select( select=select(Connection), - filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 18a3128a8047e..d7a196eba3f35 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -295,9 +295,8 @@ def get_dag_runs( limit=limit, session=session, ) - dag_runs = session.scalars(dag_run_select) return DAGRunCollectionResponse( - dag_runs=[DAGRunResponse.model_validate(dag_run, from_attributes=True) for dag_run in dag_runs], + dag_runs=[DAGRunResponse.model_validate(dr, from_attributes=True) for dr in dag_runs], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index 3560874ad1aa9..9ddd1439b199a 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -67,13 +67,9 @@ def list_dag_warnings( limit=limit, session=session, ) - dag_warnings = session.scalars(dag_warnings_select) return DAGWarningCollectionResponse( - dag_warnings=[ - DAGWarningResponse.model_validate(dag_warning, from_attributes=True) - for dag_warning in dag_warnings - ], + dag_warnings=[DAGWarningResponse.model_validate(w, from_attributes=True) for w in dag_warnings], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dags.py b/airflow/api_fastapi/core_api/routes/public/dags.py index 0383584fe1729..1416855b11c94 100644 --- a/airflow/api_fastapi/core_api/routes/public/dags.py +++ b/airflow/api_fastapi/core_api/routes/public/dags.py @@ -135,7 +135,7 @@ def get_dag_tags( session=session, ) dag_tags = session.execute(dag_tags_select).scalars().all() - return DAGTagCollectionResponse(tags=[dag_tag for dag_tag in dag_tags], total_entries=total_entries) + return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries) @dags_router.get( @@ -259,6 +259,7 @@ def patch_dags( status.HTTP_400_BAD_REQUEST, "Only `is_paused` field can be updated through the REST API" ) else: + # todo: this is not used? update_mask = ["is_paused"] dags_select, total_entries = paginated_select( @@ -269,11 +270,8 @@ def patch_dags( limit=limit, session=session, ) - dags = session.scalars(dags_select).all() - dags_to_update = {dag.dag_id for dag in dags} - session.execute( update(DagModel) .where(DagModel.dag_id.in_(dags_to_update)) @@ -282,7 +280,7 @@ def patch_dags( ) return DAGCollectionResponse( - dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags], + dags=[DAGResponse.model_validate(d, from_attributes=True) for d in dags], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/event_logs.py b/airflow/api_fastapi/core_api/routes/public/event_logs.py index 166b48995399a..7d2933365a956 100644 --- a/airflow/api_fastapi/core_api/routes/public/event_logs.py +++ b/airflow/api_fastapi/core_api/routes/public/event_logs.py @@ -126,7 +126,6 @@ def get_event_logs( base_select = base_select.where(Log.dttm > after) event_logs_select, total_entries = paginated_select( select=base_select, - filters=[], order_by=order_by, offset=offset, limit=limit, @@ -135,12 +134,6 @@ def get_event_logs( event_logs = session.scalars(event_logs_select) return EventLogCollectionResponse( - event_logs=[ - EventLogResponse.model_validate( - event_log, - from_attributes=True, - ) - for event_log in event_logs - ], + event_logs=[EventLogResponse.model_validate(e, from_attributes=True) for e in event_logs], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/import_error.py b/airflow/api_fastapi/core_api/routes/public/import_error.py index 29f783081d78c..e17abfbe1ffc9 100644 --- a/airflow/api_fastapi/core_api/routes/public/import_error.py +++ b/airflow/api_fastapi/core_api/routes/public/import_error.py @@ -90,7 +90,6 @@ def get_import_errors( """Get all import errors.""" import_errors_select, total_entries = paginated_select( select=select(ParseImportError), - filters=[], order_by=order_by, offset=offset, limit=limit, @@ -99,8 +98,6 @@ def get_import_errors( import_errors = session.scalars(import_errors_select) return ImportErrorCollectionResponse( - import_errors=[ - ImportErrorResponse.model_validate(error, from_attributes=True) for error in import_errors - ], + import_errors=[ImportErrorResponse.model_validate(i, from_attributes=True) for i in import_errors], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/pools.py b/airflow/api_fastapi/core_api/routes/public/pools.py index df14b9aae5a05..582e03ab00dbd 100644 --- a/airflow/api_fastapi/core_api/routes/public/pools.py +++ b/airflow/api_fastapi/core_api/routes/public/pools.py @@ -96,7 +96,6 @@ def get_pools( """Get all pools entries.""" pools_select, total_entries = paginated_select( select=select(Pool), - filters=[], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 6d5b427abb1dc..ed6d46dc78f02 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -169,7 +169,6 @@ def get_mapped_task_instances( limit=limit, session=session, ) - task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( @@ -335,14 +334,9 @@ def get_task_instances( limit=limit, session=session, ) - task_instances = session.scalars(task_instance_select) - return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(task_instance, from_attributes=True) - for task_instance in task_instances - ], + task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], total_entries=total_entries, ) @@ -411,7 +405,6 @@ def get_task_instances_batch( limit=limit, session=session, ) - task_instance_select = task_instance_select.options( joinedload(TI.rendered_task_instance_fields), joinedload(TI.task_instance_note) ) @@ -419,10 +412,7 @@ def get_task_instances_batch( task_instances = session.scalars(task_instance_select) return TaskInstanceCollectionResponse( - task_instances=[ - TaskInstanceResponse.model_validate(task_instance, from_attributes=True) - for task_instance in task_instances - ], + task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances], total_entries=total_entries, ) diff --git a/airflow/api_fastapi/core_api/routes/public/variables.py b/airflow/api_fastapi/core_api/routes/public/variables.py index a9e479f4f853a..541dbcb8f107a 100644 --- a/airflow/api_fastapi/core_api/routes/public/variables.py +++ b/airflow/api_fastapi/core_api/routes/public/variables.py @@ -91,7 +91,6 @@ def get_variables( """Get all Variables entries.""" variable_select, total_entries = paginated_select( select=select(Variable), - filters=[], order_by=order_by, offset=offset, limit=limit,