From 175b960ce0dd5530d96c2ed3545d10cfdc07b8f1 Mon Sep 17 00:00:00 2001 From: Kalyan R Date: Wed, 20 Nov 2024 23:57:39 +0530 Subject: [PATCH] AIP-84 Migrate GET Dag Runs endpoint to FastAPI (#43506) * add list_dag_runs * use logical_date * add tests * wip - writing tests * add tests * fix tests * Update airflow/api_fastapi/core_api/routes/public/dag_run.py * add status * Small tweak --------- Co-authored-by: pierrejeambrun --- .../endpoints/dag_run_endpoint.py | 1 + airflow/api_fastapi/common/parameters.py | 40 ++- .../core_api/datamodels/dag_run.py | 7 + .../core_api/openapi/v1-generated.yaml | 168 ++++++++++ .../core_api/routes/public/dag_run.py | 74 ++++- airflow/ui/openapi-gen/queries/common.ts | 59 ++++ airflow/ui/openapi-gen/queries/prefetch.ts | 87 ++++++ airflow/ui/openapi-gen/queries/queries.ts | 96 ++++++ airflow/ui/openapi-gen/queries/suspense.ts | 96 ++++++ .../ui/openapi-gen/requests/schemas.gen.ts | 20 ++ .../ui/openapi-gen/requests/services.gen.ts | 56 ++++ airflow/ui/openapi-gen/requests/types.gen.ts | 53 ++++ .../core_api/routes/public/test_dag_run.py | 295 +++++++++++++++++- 13 files changed, 1039 insertions(+), 13 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index b8e7f36d1fd43..00dd8ca907193 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -192,6 +192,7 @@ def _fetch_dag_runs( return session.scalars(query.offset(offset).limit(limit)).all(), total_entries +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.RUN) @format_parameters( { diff --git a/airflow/api_fastapi/common/parameters.py b/airflow/api_fastapi/common/parameters.py index 942fa80e9db72..6bfbfadf4180c 100644 --- a/airflow/api_fastapi/common/parameters.py +++ b/airflow/api_fastapi/common/parameters.py @@ -19,7 +19,19 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, List, Optional, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Generic, + Iterable, + List, + Optional, + TypeVar, + Union, + overload, +) from fastapi import Depends, HTTPException, Query from pendulum.parsing.exceptions import ParserError @@ -211,6 +223,7 @@ class SortParam(BaseParam[str]): "last_run_start_date": DagRun.start_date, "connection_id": Connection.conn_id, "import_error_id": ParseImportError.id, + "dag_run_id": DagRun.run_id, } def __init__( @@ -309,6 +322,30 @@ def depends(self, owners: list[str] = Query(default_factory=list)) -> _OwnersFil return self.set_value(owners) +class DagRunStateFilter(BaseParam[List[Optional[DagRunState]]]): + """Filter on Dag Run state.""" + + def to_orm(self, select: Select) -> Select: + if self.skip_none is False: + raise ValueError(f"Cannot set 'skip_none' to False on a {type(self)}") + + if not self.value: + return select + + conditions = [DagRun.state == state for state in self.value] + return select.where(or_(*conditions)) + + @staticmethod + def _convert_dag_run_states(states: Iterable[str] | None) -> list[DagRunState | None] | None: + if not states: + return None + return [None if s in ("none", None) else DagRunState(s) for s in states] + + def depends(self, state: list[str] = Query(default_factory=list)) -> DagRunStateFilter: + states = self._convert_dag_run_states(state) + return self.set_value(states) + + class TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]): """Filter on task instance state.""" @@ -656,6 +693,7 @@ def depends_float( # DagRun QueryLastDagRunStateFilter = Annotated[_LastDagRunStateFilter, Depends(_LastDagRunStateFilter().depends)] QueryDagIdsFilter = Annotated[DagIdsFilter, Depends(DagIdsFilter(DagRun).depends)] +QueryDagRunStateFilter = Annotated[DagRunStateFilter, Depends(DagRunStateFilter().depends)] # DAGWarning QueryDagIdInDagWarningFilter = Annotated[_DagIdFilter, Depends(_DagIdFilter(DagWarning.dag_id).depends)] diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py index 8241885aff2fe..f3343e6c407d1 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -65,3 +65,10 @@ class DAGRunResponse(BaseModel): triggered_by: DagRunTriggeredByType conf: dict note: str | None + + +class DAGRunCollectionResponse(BaseModel): + """DAG Run Collection serializer for responses.""" + + dag_runs: list[DAGRunResponse] + total_entries: int diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 2d26571b1e374..2f74f2268928f 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1656,6 +1656,158 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns: + get: + tags: + - DagRun + summary: Get Dag Runs + description: 'Get all DAG Runs. + + + This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for + all DAGs.' + operationId: get_dag_runs + parameters: + - name: dag_id + in: path + required: true + schema: + type: string + title: Dag Id + - name: limit + in: query + required: false + schema: + type: integer + minimum: 0 + default: 100 + title: Limit + - name: offset + in: query + required: false + schema: + type: integer + minimum: 0 + default: 0 + title: Offset + - name: logical_date_gte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Gte + - name: logical_date_lte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Lte + - name: start_date_gte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Gte + - name: start_date_lte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Lte + - name: end_date_gte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Gte + - name: end_date_lte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Lte + - name: updated_at_gte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Updated At Gte + - name: updated_at_lte + in: query + required: false + schema: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Updated At Lte + - name: state + in: query + required: false + schema: + type: array + items: + type: string + title: State + - name: order_by + in: query + required: false + schema: + type: string + default: id + title: Order By + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dagSources/{dag_id}: get: tags: @@ -5399,6 +5551,22 @@ components: type: object title: DAGRunClearBody description: DAG Run serializer for clear endpoint body. + DAGRunCollectionResponse: + properties: + dag_runs: + items: + $ref: '#/components/schemas/DAGRunResponse' + type: array + title: Dag Runs + total_entries: + type: integer + title: Total Entries + type: object + required: + - dag_runs + - total_entries + title: DAGRunCollectionResponse + description: DAG Run Collection serializer for responses. DAGRunPatchBody: properties: state: diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index ce9c4410aeb69..6ce60fe896d1c 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -28,11 +28,20 @@ set_dag_run_state_to_queued, set_dag_run_state_to_success, ) -from airflow.api_fastapi.common.db.common import get_session +from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.parameters import ( + QueryDagRunStateFilter, + QueryLimit, + QueryOffset, + RangeFilter, + SortParam, + datetime_range_filter_factory, +) from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.assets import AssetEventCollectionResponse, AssetEventResponse from airflow.api_fastapi.core_api.datamodels.dag_run import ( DAGRunClearBody, + DAGRunCollectionResponse, DAGRunPatchBody, DAGRunPatchStates, DAGRunResponse, @@ -229,3 +238,66 @@ def clear_dag_run( ) dag_run_cleared = session.scalar(select(DagRun).where(DagRun.id == dag_run.id)) return DAGRunResponse.model_validate(dag_run_cleared, from_attributes=True) + + +@dag_run_router.get("", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) +def get_dag_runs( + dag_id: str, + limit: QueryLimit, + offset: QueryOffset, + logical_date: Annotated[RangeFilter, Depends(datetime_range_filter_factory("logical_date", DagRun))], + start_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("start_date", DagRun))], + end_date_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("end_date", DagRun))], + update_at_range: Annotated[RangeFilter, Depends(datetime_range_filter_factory("updated_at", DagRun))], + state: QueryDagRunStateFilter, + order_by: Annotated[ + SortParam, + Depends( + SortParam( + [ + "id", + "state", + "dag_id", + "logical_date", + "dag_run_id", + "start_date", + "end_date", + "updated_at", + "external_trigger", + "conf", + ], + DagRun, + ).dynamic_depends(default="id") + ), + ], + session: Annotated[Session, Depends(get_session)], + request: Request, +) -> DAGRunCollectionResponse: + """ + Get all DAG Runs. + + This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. + """ + base_query = select(DagRun) + + if dag_id != "~": + dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + if not dag: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"The DAG with dag_id: `{dag_id}` was not found") + + base_query = base_query.filter(DagRun.dag_id == dag_id) + + dag_run_select, total_entries = paginated_select( + base_query, + [logical_date, start_date_range, end_date_range, update_at_range, state], + order_by, + offset, + limit, + session, + ) + + dag_runs = session.scalars(dag_run_select) + return DAGRunCollectionResponse( + dag_runs=[DAGRunResponse.model_validate(dag_run, from_attributes=True) for dag_run in dag_runs], + total_entries=total_entries, + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index fe281c5640f9e..736425218022a 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -396,6 +396,65 @@ export const UseDagRunServiceGetUpstreamAssetEventsKeyFn = ( useDagRunServiceGetUpstreamAssetEventsKey, ...(queryKey ?? [{ dagId, dagRunId }]), ]; +export type DagRunServiceGetDagRunsDefaultResponse = Awaited< + ReturnType +>; +export type DagRunServiceGetDagRunsQueryResult< + TData = DagRunServiceGetDagRunsDefaultResponse, + TError = unknown, +> = UseQueryResult; +export const useDagRunServiceGetDagRunsKey = "DagRunServiceGetDagRuns"; +export const UseDagRunServiceGetDagRunsKeyFn = ( + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }: { + dagId: string; + endDateGte?: string; + endDateLte?: string; + limit?: number; + logicalDateGte?: string; + logicalDateLte?: string; + offset?: number; + orderBy?: string; + startDateGte?: string; + startDateLte?: string; + state?: string[]; + updatedAtGte?: string; + updatedAtLte?: string; + }, + queryKey?: Array, +) => [ + useDagRunServiceGetDagRunsKey, + ...(queryKey ?? [ + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }, + ]), +]; export type DagSourceServiceGetDagSourceDefaultResponse = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts b/airflow/ui/openapi-gen/queries/prefetch.ts index f7872fcc7f8cc..d57dfe3d24d98 100644 --- a/airflow/ui/openapi-gen/queries/prefetch.ts +++ b/airflow/ui/openapi-gen/queries/prefetch.ts @@ -492,6 +492,93 @@ export const prefetchUseDagRunServiceGetUpstreamAssetEvents = ( }), queryFn: () => DagRunService.getUpstreamAssetEvents({ dagId, dagRunId }), }); +/** + * Get Dag Runs + * Get all DAG Runs. + * + * This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. + * @param data The data for the request. + * @param data.dagId + * @param data.limit + * @param data.offset + * @param data.logicalDateGte + * @param data.logicalDateLte + * @param data.startDateGte + * @param data.startDateLte + * @param data.endDateGte + * @param data.endDateLte + * @param data.updatedAtGte + * @param data.updatedAtLte + * @param data.state + * @param data.orderBy + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ +export const prefetchUseDagRunServiceGetDagRuns = ( + queryClient: QueryClient, + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }: { + dagId: string; + endDateGte?: string; + endDateLte?: string; + limit?: number; + logicalDateGte?: string; + logicalDateLte?: string; + offset?: number; + orderBy?: string; + startDateGte?: string; + startDateLte?: string; + state?: string[]; + updatedAtGte?: string; + updatedAtLte?: string; + }, +) => + queryClient.prefetchQuery({ + queryKey: Common.UseDagRunServiceGetDagRunsKeyFn({ + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }), + queryFn: () => + DagRunService.getDagRuns({ + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }), + }); /** * Get Dag Source * Get source code using file token. diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 74e25c0258a25..2ca159f465f51 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -622,6 +622,102 @@ export const useDagRunServiceGetUpstreamAssetEvents = < DagRunService.getUpstreamAssetEvents({ dagId, dagRunId }) as TData, ...options, }); +/** + * Get Dag Runs + * Get all DAG Runs. + * + * This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. + * @param data The data for the request. + * @param data.dagId + * @param data.limit + * @param data.offset + * @param data.logicalDateGte + * @param data.logicalDateLte + * @param data.startDateGte + * @param data.startDateLte + * @param data.endDateGte + * @param data.endDateLte + * @param data.updatedAtGte + * @param data.updatedAtLte + * @param data.state + * @param data.orderBy + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagRunServiceGetDagRuns = < + TData = Common.DagRunServiceGetDagRunsDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }: { + dagId: string; + endDateGte?: string; + endDateLte?: string; + limit?: number; + logicalDateGte?: string; + logicalDateLte?: string; + offset?: number; + orderBy?: string; + startDateGte?: string; + startDateLte?: string; + state?: string[]; + updatedAtGte?: string; + updatedAtLte?: string; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useQuery({ + queryKey: Common.UseDagRunServiceGetDagRunsKeyFn( + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }, + queryKey, + ), + queryFn: () => + DagRunService.getDagRuns({ + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }) as TData, + ...options, + }); /** * Get Dag Source * Get source code using file token. diff --git a/airflow/ui/openapi-gen/queries/suspense.ts b/airflow/ui/openapi-gen/queries/suspense.ts index 87b1a7aa6a2ba..50ccc8a3c6820 100644 --- a/airflow/ui/openapi-gen/queries/suspense.ts +++ b/airflow/ui/openapi-gen/queries/suspense.ts @@ -604,6 +604,102 @@ export const useDagRunServiceGetUpstreamAssetEventsSuspense = < DagRunService.getUpstreamAssetEvents({ dagId, dagRunId }) as TData, ...options, }); +/** + * Get Dag Runs + * Get all DAG Runs. + * + * This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. + * @param data The data for the request. + * @param data.dagId + * @param data.limit + * @param data.offset + * @param data.logicalDateGte + * @param data.logicalDateLte + * @param data.startDateGte + * @param data.startDateLte + * @param data.endDateGte + * @param data.endDateLte + * @param data.updatedAtGte + * @param data.updatedAtLte + * @param data.state + * @param data.orderBy + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagRunServiceGetDagRunsSuspense = < + TData = Common.DagRunServiceGetDagRunsDefaultResponse, + TError = unknown, + TQueryKey extends Array = unknown[], +>( + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }: { + dagId: string; + endDateGte?: string; + endDateLte?: string; + limit?: number; + logicalDateGte?: string; + logicalDateLte?: string; + offset?: number; + orderBy?: string; + startDateGte?: string; + startDateLte?: string; + state?: string[]; + updatedAtGte?: string; + updatedAtLte?: string; + }, + queryKey?: TQueryKey, + options?: Omit, "queryKey" | "queryFn">, +) => + useSuspenseQuery({ + queryKey: Common.UseDagRunServiceGetDagRunsKeyFn( + { + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }, + queryKey, + ), + queryFn: () => + DagRunService.getDagRuns({ + dagId, + endDateGte, + endDateLte, + limit, + logicalDateGte, + logicalDateLte, + offset, + orderBy, + startDateGte, + startDateLte, + state, + updatedAtGte, + updatedAtLte, + }) as TData, + ...options, + }); /** * Get Dag Source * Get source code using file token. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 0f08034c533dd..a0bb85ace80ad 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -1404,6 +1404,26 @@ export const $DAGRunClearBody = { description: "DAG Run serializer for clear endpoint body.", } as const; +export const $DAGRunCollectionResponse = { + properties: { + dag_runs: { + items: { + $ref: "#/components/schemas/DAGRunResponse", + }, + type: "array", + title: "Dag Runs", + }, + total_entries: { + type: "integer", + title: "Total Entries", + }, + }, + type: "object", + required: ["dag_runs", "total_entries"], + title: "DAGRunCollectionResponse", + description: "DAG Run Collection serializer for responses.", +} as const; + export const $DAGRunPatchBody = { properties: { state: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 467c2961420e5..63272c5e0c470 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -63,6 +63,8 @@ import type { GetUpstreamAssetEventsResponse, ClearDagRunData, ClearDagRunResponse, + GetDagRunsData, + GetDagRunsResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, @@ -1033,6 +1035,60 @@ export class DagRunService { }, }); } + + /** + * Get Dag Runs + * Get all DAG Runs. + * + * This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs. + * @param data The data for the request. + * @param data.dagId + * @param data.limit + * @param data.offset + * @param data.logicalDateGte + * @param data.logicalDateLte + * @param data.startDateGte + * @param data.startDateLte + * @param data.endDateGte + * @param data.endDateLte + * @param data.updatedAtGte + * @param data.updatedAtLte + * @param data.state + * @param data.orderBy + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ + public static getDagRuns( + data: GetDagRunsData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "GET", + url: "/public/dags/{dag_id}/dagRuns", + path: { + dag_id: data.dagId, + }, + query: { + limit: data.limit, + offset: data.offset, + logical_date_gte: data.logicalDateGte, + logical_date_lte: data.logicalDateLte, + start_date_gte: data.startDateGte, + start_date_lte: data.startDateLte, + end_date_gte: data.endDateGte, + end_date_lte: data.endDateLte, + updated_at_gte: data.updatedAtGte, + updated_at_lte: data.updatedAtLte, + state: data.state, + order_by: data.orderBy, + }, + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class DagSourceService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 21ad2839bc80b..926932379b350 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -313,6 +313,14 @@ export type DAGRunClearBody = { dry_run?: boolean; }; +/** + * DAG Run Collection serializer for responses. + */ +export type DAGRunCollectionResponse = { + dag_runs: Array; + total_entries: number; +}; + /** * DAG Run Serializer for PATCH requests. */ @@ -1300,6 +1308,24 @@ export type ClearDagRunResponse = | TaskInstanceCollectionResponse | DAGRunResponse; +export type GetDagRunsData = { + dagId: string; + endDateGte?: string | null; + endDateLte?: string | null; + limit?: number; + logicalDateGte?: string | null; + logicalDateLte?: string | null; + offset?: number; + orderBy?: string; + startDateGte?: string | null; + startDateLte?: string | null; + state?: Array; + updatedAtGte?: string | null; + updatedAtLte?: string | null; +}; + +export type GetDagRunsResponse = DAGRunCollectionResponse; + export type GetDagSourceData = { accept?: "application/json" | "text/plain" | "*/*"; dagId: string; @@ -2459,6 +2485,33 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/dagRuns": { + get: { + req: GetDagRunsData; + res: { + /** + * Successful Response + */ + 200: DAGRunCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dagSources/{dag_id}": { get: { req: GetDagSourceData; diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 89705ba85ab68..2ac22a02e31aa 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -17,7 +17,7 @@ from __future__ import annotations -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import pytest from sqlalchemy import select @@ -52,9 +52,12 @@ DAG1_RUN2_TRIGGERED_BY = DagRunTriggeredByType.ASSET DAG2_RUN1_TRIGGERED_BY = DagRunTriggeredByType.CLI DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API -START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc) -END_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc) -EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc) +START_DATE1 = datetime(2024, 1, 15, 0, 0, tzinfo=timezone.utc) +LOGICAL_DATE1 = datetime(2024, 2, 16, 0, 0, tzinfo=timezone.utc) +LOGICAL_DATE2 = datetime(2024, 2, 20, 0, 0, tzinfo=timezone.utc) +START_DATE2 = datetime(2024, 4, 15, 0, 0, tzinfo=timezone.utc) +LOGICAL_DATE3 = datetime(2024, 5, 16, 0, 0, tzinfo=timezone.utc) +LOGICAL_DATE4 = datetime(2024, 5, 25, 0, 0, tzinfo=timezone.utc) DAG1_RUN1_NOTE = "test_note" @@ -68,7 +71,7 @@ def setup(dag_maker, session=None): with dag_maker( DAG1_ID, schedule="@daily", - start_date=START_DATE, + start_date=START_DATE1, ): task1 = EmptyOperator(task_id="task_1") dag_run1 = dag_maker.create_dagrun( @@ -76,6 +79,7 @@ def setup(dag_maker, session=None): state=DAG1_RUN1_STATE, run_type=DAG1_RUN1_RUN_TYPE, triggered_by=DAG1_RUN1_TRIGGERED_BY, + logical_date=LOGICAL_DATE1, ) dag_run1.note = (DAG1_RUN1_NOTE, 1) @@ -89,13 +93,13 @@ def setup(dag_maker, session=None): state=DAG1_RUN2_STATE, run_type=DAG1_RUN2_RUN_TYPE, triggered_by=DAG1_RUN2_TRIGGERED_BY, - logical_date=EXECUTION_DATE, + logical_date=LOGICAL_DATE2, ) with dag_maker( DAG2_ID, schedule=None, - start_date=START_DATE, + start_date=START_DATE2, ): EmptyOperator(task_id="task_2") dag_maker.create_dagrun( @@ -103,14 +107,14 @@ def setup(dag_maker, session=None): state=DAG2_RUN1_STATE, run_type=DAG2_RUN1_RUN_TYPE, triggered_by=DAG2_RUN1_TRIGGERED_BY, - logical_date=EXECUTION_DATE, + logical_date=LOGICAL_DATE3, ) dag_maker.create_dagrun( run_id=DAG2_RUN2_ID, state=DAG2_RUN2_STATE, run_type=DAG2_RUN2_RUN_TYPE, triggered_by=DAG2_RUN2_TRIGGERED_BY, - logical_date=EXECUTION_DATE, + logical_date=LOGICAL_DATE4, ) dag_maker.dagbag.sync_to_db() @@ -156,6 +160,275 @@ def test_get_dag_run_not_found(self, test_client): assert body["detail"] == "The DagRun with dag_id: `test_dag1` and run_id: `invalid` was not found" +class TestGetDagRuns: + @staticmethod + def parse_datetime(datetime_str): + return datetime_str.isoformat().replace("+00:00", "Z") if datetime_str else None + + @staticmethod + def get_dag_run_dict(run: DagRun): + return { + "run_id": run.run_id, + "dag_id": run.dag_id, + "logical_date": TestGetDagRuns.parse_datetime(run.logical_date), + "queued_at": TestGetDagRuns.parse_datetime(run.queued_at), + "start_date": TestGetDagRuns.parse_datetime(run.start_date), + "end_date": TestGetDagRuns.parse_datetime(run.end_date), + "data_interval_start": TestGetDagRuns.parse_datetime(run.data_interval_start), + "data_interval_end": TestGetDagRuns.parse_datetime(run.data_interval_end), + "last_scheduling_decision": TestGetDagRuns.parse_datetime(run.last_scheduling_decision), + "run_type": run.run_type, + "state": run.state, + "external_trigger": run.external_trigger, + "triggered_by": run.triggered_by.value, + "conf": run.conf, + "note": run.note, + } + + @pytest.mark.parametrize("dag_id, total_entries", [(DAG1_ID, 2), (DAG2_ID, 2), ("~", 4)]) + def test_get_dag_runs(self, test_client, session, dag_id, total_entries): + response = test_client.get(f"/public/dags/{dag_id}/dagRuns") + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == total_entries + for each in body["dag_runs"]: + run = ( + session.query(DagRun) + .where(DagRun.dag_id == each["dag_id"], DagRun.run_id == each["run_id"]) + .one() + ) + expected = self.get_dag_run_dict(run) + assert each == expected + + def test_get_dag_runs_not_found(self, test_client): + response = test_client.get("/public/dags/invalid/dagRuns") + assert response.status_code == 404 + body = response.json() + assert body["detail"] == "The DAG with dag_id: `invalid` was not found" + + def test_invalid_order_by_raises_400(self, test_client): + response = test_client.get("/public/dags/test_dag1/dagRuns?order_by=invalid") + assert response.status_code == 400 + body = response.json() + assert ( + body["detail"] + == "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" + ) + + @pytest.mark.parametrize( + "order_by, expected_dag_id_order", + [ + ("id", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("state", [DAG1_RUN2_ID, DAG1_RUN1_ID]), + ("dag_id", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("logical_date", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("dag_run_id", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("start_date", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("end_date", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("updated_at", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("external_trigger", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ("conf", [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ], + ) + def test_return_correct_results_with_order_by(self, test_client, order_by, expected_dag_id_order): + response = test_client.get("/public/dags/test_dag1/dagRuns", params={"order_by": order_by}) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 2 + assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_order + + @pytest.mark.parametrize( + "query_params, expected_dag_id_order", + [ + ({}, [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ({"limit": 1}, [DAG1_RUN1_ID]), + ({"limit": 3}, [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ({"offset": 1}, [DAG1_RUN2_ID]), + ({"offset": 2}, []), + ({"limit": 1, "offset": 1}, [DAG1_RUN2_ID]), + ({"limit": 1, "offset": 2}, []), + ], + ) + def test_limit_and_offset(self, test_client, query_params, expected_dag_id_order): + response = test_client.get("/public/dags/test_dag1/dagRuns", params=query_params) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 2 + assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_order + + @pytest.mark.parametrize( + "query_params, expected_detail", + [ + ( + {"limit": 1, "offset": -1}, + [ + { + "type": "greater_than_equal", + "loc": ["query", "offset"], + "msg": "Input should be greater than or equal to 0", + "input": "-1", + "ctx": {"ge": 0}, + } + ], + ), + ( + {"limit": -1, "offset": 1}, + [ + { + "type": "greater_than_equal", + "loc": ["query", "limit"], + "msg": "Input should be greater than or equal to 0", + "input": "-1", + "ctx": {"ge": 0}, + } + ], + ), + ( + {"limit": -1, "offset": -1}, + [ + { + "type": "greater_than_equal", + "loc": ["query", "limit"], + "msg": "Input should be greater than or equal to 0", + "input": "-1", + "ctx": {"ge": 0}, + }, + { + "type": "greater_than_equal", + "loc": ["query", "offset"], + "msg": "Input should be greater than or equal to 0", + "input": "-1", + "ctx": {"ge": 0}, + }, + ], + ), + ], + ) + def test_bad_limit_and_offset(self, test_client, query_params, expected_detail): + response = test_client.get("/public/dags/test_dag1/dagRuns", params=query_params) + assert response.status_code == 422 + assert response.json()["detail"] == expected_detail + + @pytest.mark.parametrize( + "dag_id, query_params, expected_dag_id_list", + [ + (DAG1_ID, {"logical_date_gte": LOGICAL_DATE1.isoformat()}, [DAG1_RUN1_ID, DAG1_RUN2_ID]), + (DAG2_ID, {"logical_date_lte": LOGICAL_DATE3.isoformat()}, [DAG2_RUN1_ID]), + ( + "~", + { + "start_date_gte": START_DATE1.isoformat(), + "start_date_lte": (START_DATE2 - timedelta(days=1)).isoformat(), + }, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), + ( + DAG1_ID, + { + "end_date_gte": START_DATE2.isoformat(), + "end_date_lte": (datetime.now(tz=timezone.utc) + timedelta(days=1)).isoformat(), + }, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), + ( + DAG1_ID, + { + "logical_date_gte": LOGICAL_DATE1.isoformat(), + "logical_date_lte": LOGICAL_DATE2.isoformat(), + }, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), + ( + DAG2_ID, + { + "start_date_gte": START_DATE2.isoformat(), + "end_date_lte": (datetime.now(tz=timezone.utc) + timedelta(days=1)).isoformat(), + }, + [DAG2_RUN1_ID, DAG2_RUN2_ID], + ), + (DAG1_ID, {"state": DagRunState.SUCCESS.value}, [DAG1_RUN1_ID]), + (DAG2_ID, {"state": DagRunState.FAILED.value}, []), + ( + DAG1_ID, + {"state": DagRunState.SUCCESS.value, "logical_date_gte": LOGICAL_DATE1.isoformat()}, + [DAG1_RUN1_ID], + ), + ( + DAG1_ID, + {"state": DagRunState.FAILED.value, "start_date_gte": START_DATE1.isoformat()}, + [DAG1_RUN2_ID], + ), + ], + ) + def test_filters(self, test_client, dag_id, query_params, expected_dag_id_list): + response = test_client.get(f"/public/dags/{dag_id}/dagRuns", params=query_params) + assert response.status_code == 200 + body = response.json() + assert [each["run_id"] for each in body["dag_runs"]] == expected_dag_id_list + + def test_bad_filters(self, test_client): + query_params = { + "logical_date_gte": "invalid", + "start_date_gte": "invalid", + "end_date_gte": "invalid", + "logical_date_lte": "invalid", + "start_date_lte": "invalid", + "end_date_lte": "invalid", + } + expected_detail = [ + { + "type": "datetime_from_date_parsing", + "loc": ["query", "logical_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["query", "logical_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["query", "start_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["query", "start_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["query", "end_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["query", "end_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + ] + response = test_client.get(f"/public/dags/{DAG1_ID}/dagRuns", params=query_params) + assert response.status_code == 422 + body = response.json() + assert body["detail"] == expected_detail + + def test_invalid_state(self, test_client): + with pytest.raises(ValueError, match="'invalid' is not a valid DagRunState"): + test_client.get(f"/public/dags/{DAG1_ID}/dagRuns", params={"state": "invalid"}) + + class TestPatchDagRun: @pytest.mark.parametrize( "dag_id, run_id, patch_body, response_body", @@ -271,7 +544,7 @@ class TestGetDagRunAssetTriggerEvents: def test_should_respond_200(self, test_client, dag_maker, session): asset1 = Asset(uri="ds1") - with dag_maker(dag_id="source_dag", start_date=START_DATE, session=session): + with dag_maker(dag_id="source_dag", start_date=START_DATE1, session=session): EmptyOperator(task_id="task", outlets=[asset1]) dr = dag_maker.create_dagrun() ti = dr.task_instances[0] @@ -286,7 +559,7 @@ def test_should_respond_200(self, test_client, dag_maker, session): ) session.add(event) - with dag_maker(dag_id="TEST_DAG_ID", start_date=START_DATE, session=session): + with dag_maker(dag_id="TEST_DAG_ID", start_date=START_DATE1, session=session): pass dr = dag_maker.create_dagrun(run_id="TEST_DAG_RUN_ID", run_type=DagRunType.ASSET_TRIGGERED) dr.consumed_asset_events.append(event)