Skip to content

Commit

Permalink
AIP-84 Refactor Filter Query Parameters (apache#43947)
Browse files (browse the repository at this point in the history)
* Add FilterParam and type_filter_param_factory

* Refactor Get Event Logs with filter_param_factory

* Refactor add type option for filter_param_factory

* Fix Get Event Logs with latest paginated_select

* Refactor Get Asset Events

* Refactor List Dag Warnings

* Refactor DagRun related

- QueryLastDagRunStateFilter
- dag_ids of get_dag_stats

* Remove unused parameters

* Refactor on Dag parameters

* Add any_equal to FilterParam

* Refactor Task Instance

* Fix Get Event Logs type

* Fix after rebase

* Refactor with search_param_factory

* Refactor QueryLastDagRunStateFilter

* Fix get_list_dag_runs_batch
  • Loading branch information
jason810496 authored Dec 3, 2024
1 parent d059d4a commit 31ba41e
Show file tree
Hide file tree
Showing 13 changed files with 339 additions and 524 deletions.
597 changes: 191 additions & 406 deletions airflow/api_fastapi/common/parameters.py

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2969,6 +2969,29 @@ paths:
description: Get all Event Logs.
operationId: get_event_logs
parameters:
- name: limit
in: query
required: false
schema:
type: integer
minimum: 0
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
minimum: 0
default: 0
title: Offset
- name: order_by
in: query
required: false
schema:
type: string
default: id
title: Order By
- name: dag_id
in: query
required: false
Expand Down Expand Up @@ -3063,29 +3086,6 @@ paths:
format: date-time
- type: 'null'
title: After
- name: limit
in: query
required: false
schema:
type: integer
minimum: 0
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
minimum: 0
default: 0
title: Offset
- name: order_by
in: query
required: false
schema:
type: string
default: id
title: Order By
responses:
'200':
description: Successful Response
Expand Down
27 changes: 17 additions & 10 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,14 @@

from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import (
FilterParam,
OptionalDateTimeQuery,
QueryAssetDagIdPatternSearch,
QueryAssetIdFilter,
QueryLimit,
QueryOffset,
QuerySourceDagIdFilter,
QuerySourceMapIndexFilter,
QuerySourceRunIdFilter,
QuerySourceTaskIdFilter,
QueryUriPatternSearch,
SortParam,
filter_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.assets import (
Expand Down Expand Up @@ -135,11 +132,21 @@ def get_asset_events(
).dynamic_depends("timestamp")
),
],
asset_id: QueryAssetIdFilter,
source_dag_id: QuerySourceDagIdFilter,
source_task_id: QuerySourceTaskIdFilter,
source_run_id: QuerySourceRunIdFilter,
source_map_index: QuerySourceMapIndexFilter,
asset_id: Annotated[
FilterParam[int | None], Depends(filter_param_factory(AssetEvent.asset_id, int | None))
],
source_dag_id: Annotated[
FilterParam[str | None], Depends(filter_param_factory(AssetEvent.source_dag_id, str | None))
],
source_task_id: Annotated[
FilterParam[str | None], Depends(filter_param_factory(AssetEvent.source_task_id, str | None))
],
source_run_id: Annotated[
FilterParam[str | None], Depends(filter_param_factory(AssetEvent.source_run_id, str | None))
],
source_map_index: Annotated[
FilterParam[int | None], Depends(filter_param_factory(AssetEvent.source_map_index, int | None))
],
session: SessionDep,
) -> AssetEventCollectionResponse:
"""Get asset events."""
Expand Down
8 changes: 4 additions & 4 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
)
from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import (
DagIdsFilter,
FilterOptionEnum,
FilterParam,
LimitFilter,
OffsetFilter,
QueryDagRunStateFilter,
Expand Down Expand Up @@ -374,7 +375,7 @@ def get_list_dag_runs_batch(
dag_id: Literal["~"], body: DAGRunsBatchBody, session: SessionDep
) -> DAGRunCollectionResponse:
"""Get a list of DAG Runs."""
dag_ids = DagIdsFilter(DagRun, body.dag_ids)
dag_ids = FilterParam(DagRun.dag_id, body.dag_ids, FilterOptionEnum.IN)
logical_date = RangeFilter(
Range(lower_bound=body.logical_date_gte, upper_bound=body.logical_date_lte),
attribute=DagRun.logical_date,
Expand All @@ -387,8 +388,7 @@ def get_list_dag_runs_batch(
Range(lower_bound=body.end_date_gte, upper_bound=body.end_date_lte),
attribute=DagRun.end_date,
)

state = QueryDagRunStateFilter(body.states)
state = FilterParam(DagRun.state, body.states, FilterOptionEnum.ANY_EQUAL)

offset = OffsetFilter(body.page_offset)
limit = LimitFilter(body.page_limit)
Expand Down
16 changes: 13 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/dag_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,28 @@

from __future__ import annotations

from fastapi import status
from typing import Annotated

from fastapi import Depends, status

from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
)
from airflow.api_fastapi.common.db.dag_runs import dagruns_select_with_state_count
from airflow.api_fastapi.common.parameters import QueryDagIdsFilter
from airflow.api_fastapi.common.parameters import (
FilterOptionEnum,
FilterParam,
filter_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.dag_stats import (
DagStatsCollectionResponse,
DagStatsResponse,
DagStatsStateResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models.dagrun import DagRun
from airflow.utils.state import DagRunState

dag_stats_router = AirflowRouter(tags=["DagStats"], prefix="/dagStats")
Expand All @@ -48,7 +55,10 @@
)
def get_dag_stats(
session: SessionDep,
dag_ids: QueryDagIdsFilter,
dag_ids: Annotated[
FilterParam[list[str]],
Depends(filter_param_factory(DagRun.dag_id, list[str], FilterOptionEnum.IN, "dag_ids")),
],
) -> DagStatsCollectionResponse:
"""Get Dag statistics."""
dagruns_select, _ = paginated_select(
Expand Down
13 changes: 8 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/dag_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,17 @@
paginated_select,
)
from airflow.api_fastapi.common.parameters import (
QueryDagIdInDagWarningFilter,
FilterParam,
QueryLimit,
QueryOffset,
QueryWarningTypeFilter,
SortParam,
filter_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.dag_warning import (
DAGWarningCollectionResponse,
)
from airflow.models import DagWarning
from airflow.models.dagwarning import DagWarning, DagWarningType

dag_warning_router = AirflowRouter(tags=["DagWarning"])

Expand All @@ -46,8 +46,11 @@
"/dagWarnings",
)
def list_dag_warnings(
dag_id: QueryDagIdInDagWarningFilter,
warning_type: QueryWarningTypeFilter,
dag_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(DagWarning.dag_id, str | None))],
warning_type: Annotated[
FilterParam[DagWarningType | None],
Depends(filter_param_factory(DagWarning.warning_type, DagWarningType | None)),
],
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
Expand Down
77 changes: 42 additions & 35 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@
from datetime import datetime
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
from fastapi import Depends, HTTPException, status
from sqlalchemy import select

from airflow.api_fastapi.common.db.common import (
SessionDep,
paginated_select,
)
from airflow.api_fastapi.common.parameters import (
FilterOptionEnum,
FilterParam,
QueryLimit,
QueryOffset,
SortParam,
filter_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.event_logs import (
Expand Down Expand Up @@ -83,46 +86,50 @@ def get_event_logs(
).dynamic_depends()
),
],
dag_id: str | None = None,
task_id: str | None = None,
run_id: str | None = None,
map_index: int | None = None,
try_number: int | None = None,
owner: str | None = None,
event: str | None = None,
excluded_events: list[str] | None = Query(None),
included_events: list[str] | None = Query(None),
before: datetime | None = None,
after: datetime | None = None,
dag_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(Log.dag_id, str | None))],
task_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(Log.task_id, str | None))],
run_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(Log.run_id, str | None))],
map_index: Annotated[FilterParam[int | None], Depends(filter_param_factory(Log.map_index, int | None))],
try_number: Annotated[FilterParam[int | None], Depends(filter_param_factory(Log.try_number, int | None))],
owner: Annotated[FilterParam[str | None], Depends(filter_param_factory(Log.owner, str | None))],
event: Annotated[FilterParam[str | None], Depends(filter_param_factory(Log.event, str | None))],
excluded_events: Annotated[
FilterParam[list[str] | None],
Depends(
filter_param_factory(Log.event, list[str] | None, FilterOptionEnum.NOT_IN, "excluded_events")
),
],
included_events: Annotated[
FilterParam[list[str] | None],
Depends(filter_param_factory(Log.event, list[str] | None, FilterOptionEnum.IN, "included_events")),
],
before: Annotated[
FilterParam[datetime | None],
Depends(filter_param_factory(Log.dttm, datetime | None, FilterOptionEnum.LESS_THAN, "before")),
],
after: Annotated[
FilterParam[datetime | None],
Depends(filter_param_factory(Log.dttm, datetime | None, FilterOptionEnum.GREATER_THAN, "after")),
],
) -> EventLogCollectionResponse:
"""Get all Event Logs."""
query = select(Log).group_by(Log.id)
# TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11`
if dag_id is not None:
query = query.where(Log.dag_id == dag_id)
if task_id is not None:
query = query.where(Log.task_id == task_id)
if run_id is not None:
query = query.where(Log.run_id == run_id)
if map_index is not None:
query = query.where(Log.map_index == map_index)
if try_number is not None:
query = query.where(Log.try_number == try_number)
if owner is not None:
query = query.where(Log.owner == owner)
if event is not None:
query = query.where(Log.event == event)
if excluded_events is not None:
query = query.where(Log.event.notin_(excluded_events))
if included_events is not None:
query = query.where(Log.event.in_(included_events))
if before is not None:
query = query.where(Log.dttm < before)
if after is not None:
query = query.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
statement=query,
order_by=order_by,
filters=[
dag_id,
task_id,
run_id,
map_index,
try_number,
owner,
event,
excluded_events,
included_events,
before,
after,
],
offset=offset,
limit=limit,
session=session,
Expand Down
26 changes: 17 additions & 9 deletions airflow/api_fastapi/core_api/routes/public/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,13 @@
paginated_select,
)
from airflow.api_fastapi.common.parameters import (
QueryJobExecutorClassFilter,
QueryJobHostnameFilter,
QueryJobStateFilter,
QueryJobTypeFilter,
FilterParam,
QueryLimit,
QueryOffset,
RangeFilter,
SortParam,
datetime_range_filter_factory,
filter_param_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.job import (
Expand Down Expand Up @@ -84,15 +82,25 @@ def get_jobs(
),
],
session: SessionDep,
state: QueryJobStateFilter,
job_type: QueryJobTypeFilter,
hostname: QueryJobHostnameFilter,
executor_class: QueryJobExecutorClassFilter,
state: Annotated[
FilterParam[str | None], Depends(filter_param_factory(Job.state, str | None, filter_name="job_state"))
],
job_type: Annotated[
FilterParam[str | None],
Depends(filter_param_factory(Job.job_type, str | None, filter_name="job_type")),
],
hostname: Annotated[
FilterParam[str | None],
Depends(filter_param_factory(Job.hostname, str | None, filter_name="hostname")),
],
executor_class: Annotated[
FilterParam[str | None],
Depends(filter_param_factory(Job.executor_class, str | None, filter_name="executor_class")),
],
is_alive: bool | None = None,
) -> JobCollectionResponse:
"""Get all jobs."""
base_select = select(Job).where(Job.state == JobState.RUNNING).order_by(Job.latest_heartbeat.desc())
# TODO: Refactor using the `FilterParam` class in commit `574b72e41cc5ed175a2bbf4356522589b836bb11`

jobs_select, total_entries = paginated_select(
statement=base_select,
Expand Down
Loading

0 comments on commit 31ba41e

Please sign in to comment.