Skip to content

Commit

Permalink
AIP-84 Migrate GET Dag Runs endpoint to FastAPI (apache#43506)
Browse files Browse the repository at this point in the history
* add list_dag_runs

* use logical_date

* add tests

* wip - writing tests

* add tests

* fix tests

* Update airflow/api_fastapi/core_api/routes/public/dag_run.py

* add status

* Small tweak

---------

Co-authored-by: pierrejeambrun <[email protected]>
  • Loading branch information
rawwar and pierrejeambrun authored Nov 20, 2024
1 parent 8831b31 commit 175b960
Show file tree
Hide file tree
Showing 13 changed files with 1,039 additions and 13 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def _fetch_dag_runs(
return session.scalars(query.offset(offset).limit(limit)).all(), total_entries


@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@format_parameters(
{
Expand Down
40 changes: 39 additions & 1 deletion airflow/api_fastapi/common/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@

from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, List, Optional, TypeVar, Union, overload
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Generic,
Iterable,
List,
Optional,
TypeVar,
Union,
overload,
)

from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
Expand Down Expand Up @@ -211,6 +223,7 @@ class SortParam(BaseParam[str]):
"last_run_start_date": DagRun.start_date,
"connection_id": Connection.conn_id,
"import_error_id": ParseImportError.id,
"dag_run_id": DagRun.run_id,
}

def __init__(
Expand Down Expand Up @@ -309,6 +322,30 @@ def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil
return self.set_value(owners)


class DagRunStateFilter(BaseParam[List[Optional[DagRunState]]]):
    """
    Filter on Dag Run state.

    The filter value is a list of ``DagRunState`` members in which a ``None`` entry
    stands for "no state".  An empty or unset value leaves the query untouched.
    """

    def to_orm(self, select: Select) -> Select:
        """
        Restrict ``select`` to Dag Runs whose state matches any requested state.

        :param select: statement to narrow down
        :return: ``select`` unchanged when no state was requested, otherwise
            ``select`` with the OR-ed state conditions applied
        :raises ValueError: if ``skip_none`` has been set to ``False``
        """
        if self.skip_none is False:
            raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}")

        if not self.value:
            return select

        # One equality per requested state, OR-ed together.  A ``None`` entry is
        # compared with ``== None``, which SQLAlchemy renders as ``IS NULL``.
        conditions = [DagRun.state == state for state in self.value]
        return select.where(or_(*conditions))

    @staticmethod
    def _convert_dag_run_states(states: Iterable[str] | None) -> list[DagRunState | None] | None:
        """
        Convert raw query-string values into ``DagRunState`` members.

        The literal string ``"none"`` (or an actual ``None``) maps to ``None``.

        :param states: raw values taken from the ``state`` query parameter
        :return: ``None`` when ``states`` is empty or unset, otherwise the converted list
        :raises HTTPException: with status 422 when a value is not a valid Dag Run state
        """
        if not states:
            return None
        try:
            return [None if s in ("none", None) else DagRunState(s) for s in states]
        except ValueError as err:
            # The values come straight from the client; an unknown state is a request
            # validation problem, not a server error, so answer with 422 instead of
            # letting the bare ValueError escape the dependency.
            valid_values = ", ".join(member.value for member in DagRunState)
            raise HTTPException(
                status_code=422,
                detail=f"Invalid value for state. Valid values are {valid_values} or 'none'",
            ) from err

    def depends(self, state: list[str] = Query(default_factory=list)) -> DagRunStateFilter:
        """
        FastAPI dependency entry point: read ``state`` from the query string and store it.

        :param state: repeated ``state`` query parameter; defaults to an empty list
        :return: the result of ``set_value`` called with the converted states
        :raises HTTPException: with status 422 when a value is not a valid Dag Run state
        """
        states = self._convert_dag_run_states(state)
        return self.set_value(states)


class TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
"""Filter on task instance state."""

Expand Down Expand Up @@ -656,6 +693,7 @@ def depends_float(
# DagRun
QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)]
QueryDagIdsFilter = Annotated[DagIdsFilter, Depends(DagIdsFilter(DagRun).depends)]
QueryDagRunStateFilter = Annotated[DagRunStateFilter, Depends(DagRunStateFilter().depends)]

# DAGWarning
QueryDagIdInDagWarningFilter = Annotated[_DagIdFilter, Depends(_DagIdFilter(DagWarning.dag_id).depends)]
Expand Down
7 changes: 7 additions & 0 deletions airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,10 @@ class DAGRunResponse(BaseModel):
triggered_by: DagRunTriggeredByType
conf: dict
note: str | None


class DAGRunCollectionResponse(BaseModel):
    """DAG Run Collection serializer for responses."""

    # NOTE: the class docstring above is published as the schema description in the
    # generated OpenAPI document, so treat it as user-facing text.

    # Serialized DAG Runs carried by this response.
    dag_runs: list[DAGRunResponse]
    # Count returned next to the list; the list endpoint fills it from the second value
    # returned by ``paginated_select``.  Presumably the number of matching runs before
    # limit/offset are applied -- confirm against ``paginated_select``.
    total_entries: int
168 changes: 168 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,158 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns:
get:
tags:
- DagRun
summary: Get Dag Runs
description: 'Get all DAG Runs.
This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for
all DAGs.'
operationId: get_dag_runs
parameters:
- name: dag_id
in: path
required: true
schema:
type: string
title: Dag Id
- name: limit
in: query
required: false
schema:
type: integer
minimum: 0
default: 100
title: Limit
- name: offset
in: query
required: false
schema:
type: integer
minimum: 0
default: 0
title: Offset
- name: logical_date_gte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Logical Date Gte
- name: logical_date_lte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Logical Date Lte
- name: start_date_gte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Start Date Gte
- name: start_date_lte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Start Date Lte
- name: end_date_gte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: End Date Gte
- name: end_date_lte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: End Date Lte
- name: updated_at_gte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Updated At Gte
- name: updated_at_lte
in: query
required: false
schema:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Updated At Lte
- name: state
in: query
required: false
schema:
type: array
items:
type: string
title: State
- name: order_by
in: query
required: false
schema:
type: string
default: id
title: Order By
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRunCollectionResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dagSources/{dag_id}:
get:
tags:
Expand Down Expand Up @@ -5399,6 +5551,22 @@ components:
type: object
title: DAGRunClearBody
description: DAG Run serializer for clear endpoint body.
DAGRunCollectionResponse:
properties:
dag_runs:
items:
$ref: '#/components/schemas/DAGRunResponse'
type: array
title: Dag Runs
total_entries:
type: integer
title: Total Entries
type: object
required:
- dag_runs
- total_entries
title: DAGRunCollectionResponse
description: DAG Run Collection serializer for responses.
DAGRunPatchBody:
properties:
state:
Expand Down
74 changes: 73 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,20 @@
set_dag_run_state_to_queued,
set_dag_run_state_to_success,
)
from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.parameters import (
QueryDagRunStateFilter,
QueryLimit,
QueryOffset,
RangeFilter,
SortParam,
datetime_range_filter_factory,
)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse, AssetEventResponse
from airflow.api_fastapi.core_api.datamodels.dag_run import (
DAGRunClearBody,
DAGRunCollectionResponse,
DAGRunPatchBody,
DAGRunPatchStates,
DAGRunResponse,
Expand Down Expand Up @@ -229,3 +238,66 @@ def clear_dag_run(
)
dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id))
return DAGRunResponse.model_validate(dag_run_cleared, from_attributes=True)


@dag_run_router.get("", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]))
def get_dag_runs(
    dag_id: str,
    limit: QueryLimit,
    offset: QueryOffset,
    logical_date: Annotated[RangeFilter, Depends(datetime_range_filter_factory("logical_date", DagRun))],
    start_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("start_date", DagRun))],
    end_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("end_date", DagRun))],
    update_at_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("updated_at", DagRun))],
    state: QueryDagRunStateFilter,
    order_by: Annotated[
        SortParam,
        Depends(
            SortParam(
                [
                    "id",
                    "state",
                    "dag_id",
                    "logical_date",
                    "dag_run_id",
                    "start_date",
                    "end_date",
                    "updated_at",
                    "external_trigger",
                    "conf",
                ],
                DagRun,
            ).dynamic_depends(default="id")
        ),
    ],
    session: Annotated[Session, Depends(get_session)],
    request: Request,
) -> DAGRunCollectionResponse:
    """
    Get all DAG Runs.
    This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs.
    """
    # NOTE: the docstring above is published as the operation description in the
    # generated OpenAPI document; keep it in sync with the spec when editing.
    query = select(DagRun)

    # "~" means "every DAG"; any other value must resolve in the DagBag, otherwise 404.
    if dag_id != "~":
        if not request.app.state.dag_bag.get_dag(dag_id):
            raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found")
        query = query.where(DagRun.dag_id == dag_id)

    # Range and state filters, sorting and pagination are all applied by the shared helper,
    # which also hands back the entry count reported to the client.
    filters = [logical_date, start_date_range, end_date_range, update_at_range, state]
    paginated_query, total_entries = paginated_select(query, filters, order_by, offset, limit, session)

    serialized_runs = [
        DAGRunResponse.model_validate(run, from_attributes=True) for run in session.scalars(paginated_query)
    ]
    return DAGRunCollectionResponse(dag_runs=serialized_runs, total_entries=total_entries)
Loading

0 comments on commit 175b960

Please sign in to comment.