From d8c91aa6195104a5cf994fc2086d31d6743a6e40 Mon Sep 17 00:00:00 2001 From: Kalyan R Date: Tue, 26 Nov 2024 15:12:52 +0530 Subject: [PATCH] AIP-84 Migrate POST list Dag Runs(batch) endpoint to FastAPI (#44170) * init list dag runs batch * finish dag runs batch * remove all() for scalars * working tests * fix * fix * update tests to use dag_run_id instead of run_id * fix test * add tests for reverse order * refactor * refactor --- .../endpoints/dag_run_endpoint.py | 1 + .../core_api/datamodels/dag_run.py | 18 +- .../core_api/openapi/v1-generated.yaml | 124 ++++++ .../core_api/routes/public/dag_run.py | 65 ++- airflow/ui/openapi-gen/queries/common.ts | 3 + airflow/ui/openapi-gen/queries/queries.ts | 44 ++ .../ui/openapi-gen/requests/schemas.gen.ts | 138 ++++++ .../ui/openapi-gen/requests/services.gen.ts | 31 ++ airflow/ui/openapi-gen/requests/types.gen.ts | 51 +++ .../core_api/routes/public/test_dag_run.py | 412 +++++++++++++++++- 10 files changed, 873 insertions(+), 14 deletions(-) diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 316ab2b8a7de7..dadfb3e4f42f6 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -265,6 +265,7 @@ def get_dag_runs( raise BadRequest("DAGRunCollectionSchema error", detail=str(e)) +@mark_fastapi_migration_done @security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py index f569e59a3e032..55240d15e55ff 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -20,7 +20,7 @@ from datetime import datetime from enum import Enum -from pydantic import Field +from pydantic import AwareDatetime, Field, NonNegativeInt from airflow.api_fastapi.core_api.base import BaseModel from airflow.utils.state import DagRunState @@ -73,3 +73,19 @@ class DAGRunCollectionResponse(BaseModel): dag_runs: list[DAGRunResponse] total_entries: int + + +class DAGRunsBatchBody(BaseModel): + """List DAG Runs body for batch endpoint.""" + + order_by: str | None = None + page_offset: NonNegativeInt = 0 + page_limit: NonNegativeInt = 100 + dag_ids: list[str] | None = None + states: list[DagRunState | None] | None = None + logical_date_gte: AwareDatetime | None = None + logical_date_lte: AwareDatetime | None = None + start_date_gte: AwareDatetime | None = None + start_date_lte: AwareDatetime | None = None + end_date_gte: AwareDatetime | None = None + end_date_lte: AwareDatetime | None = None diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index e0ce72573bb76..30f409bff682d 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1828,6 +1828,58 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + /public/dags/{dag_id}/dagRuns/list: + post: + tags: + - DagRun + summary: Get List Dag Runs Batch + description: Get a list of DAG Runs. + operationId: get_list_dag_runs_batch + parameters: + - name: dag_id + in: path + required: true + schema: + const: '~' + type: string + title: Dag Id + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunsBatchBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunCollectionResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dagSources/{dag_id}: get: tags: @@ -6342,6 +6394,78 @@ components: - asset_triggered title: DAGRunTypes description: DAG Run Types for responses. + DAGRunsBatchBody: + properties: + order_by: + anyOf: + - type: string + - type: 'null' + title: Order By + page_offset: + type: integer + minimum: 0.0 + title: Page Offset + default: 0 + page_limit: + type: integer + minimum: 0.0 + title: Page Limit + default: 100 + dag_ids: + anyOf: + - items: + type: string + type: array + - type: 'null' + title: Dag Ids + states: + anyOf: + - items: + anyOf: + - $ref: '#/components/schemas/DagRunState' + - type: 'null' + type: array + - type: 'null' + title: States + logical_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Gte + logical_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Logical Date Lte + start_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Gte + start_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Start Date Lte + end_date_gte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Gte + end_date_lte: + anyOf: + - type: string + format: date-time + - type: 'null' + title: End Date Lte + type: object + title: DAGRunsBatchBody + description: List DAG Runs body for batch endpoint. DAGSourceResponse: properties: content: diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index e9905cbc83a94..5b3eb0e66e505 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import Annotated, cast +from typing import Annotated, Literal, cast from fastapi import Depends, HTTPException, Query, Request, status from sqlalchemy import select @@ -30,9 +30,13 @@ ) from airflow.api_fastapi.common.db.common import get_session, paginated_select from airflow.api_fastapi.common.parameters import ( + DagIdsFilter, + LimitFilter, + OffsetFilter, QueryDagRunStateFilter, QueryLimit, QueryOffset, + Range, RangeFilter, SortParam, datetime_range_filter_factory, @@ -45,6 +49,7 @@ DAGRunPatchBody, DAGRunPatchStates, DAGRunResponse, + DAGRunsBatchBody, ) from airflow.api_fastapi.core_api.datamodels.task_instances import ( TaskInstanceCollectionResponse, @@ -296,3 +301,61 @@ def get_dag_runs( dag_runs=dag_runs, total_entries=total_entries, ) + + +@dag_run_router.post("/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) +def get_list_dag_runs_batch( + dag_id: Literal["~"], body: DAGRunsBatchBody, session: Annotated[Session, Depends(get_session)] +) -> DAGRunCollectionResponse: + """Get a list of DAG Runs.""" + dag_ids = DagIdsFilter(DagRun, body.dag_ids) + logical_date = RangeFilter( + Range(lower_bound=body.logical_date_gte, upper_bound=body.logical_date_lte), + attribute=DagRun.logical_date, + ) + start_date = RangeFilter( + Range(lower_bound=body.start_date_gte, upper_bound=body.start_date_lte), + attribute=DagRun.start_date, + ) + end_date = RangeFilter( + Range(lower_bound=body.end_date_gte, upper_bound=body.end_date_lte), + attribute=DagRun.end_date, + ) + + state = QueryDagRunStateFilter(body.states) + + offset = OffsetFilter(body.page_offset) + limit = LimitFilter(body.page_limit) + + order_by = SortParam( + [ + "id", + "state", + "dag_id", + "logical_date", + "dag_run_id", + "start_date", + "end_date", + "updated_at", + "external_trigger", + "conf", + ], + DagRun, + ).set_value(body.order_by) + + base_query = select(DagRun) + dag_runs_select, total_entries = paginated_select( + statement=base_query, + filters=[dag_ids, logical_date, start_date, end_date, state], + order_by=order_by, + offset=offset, + limit=limit, + session=session, + ) + + dag_runs = session.scalars(dag_runs_select) + + return DAGRunCollectionResponse( + dag_runs=dag_runs, + total_entries=total_entries, + ) diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 86c9bb9819c05..49bf81250962f 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1513,6 +1513,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited< export type DagRunServiceClearDagRunMutationResult = Awaited< ReturnType >; +export type DagRunServiceGetListDagRunsBatchMutationResult = Awaited< + ReturnType +>; export type TaskInstanceServiceGetTaskInstancesBatchMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 0dcbf340df550..4e74429568206 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -40,6 +40,7 @@ import { DAGPatchBody, DAGRunClearBody, DAGRunPatchBody, + DAGRunsBatchBody, DagRunState, DagWarningType, PoolPatchBody, @@ -2619,6 +2620,49 @@ export const useDagRunServiceClearDagRun = < }) as unknown as Promise, ...options, }); +/** + * Get List Dag Runs Batch + * Get a list of DAG Runs. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ +export const useDagRunServiceGetListDagRunsBatch = < + TData = Common.DagRunServiceGetListDagRunsBatchMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: "~"; + requestBody: DAGRunsBatchBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: "~"; + requestBody: DAGRunsBatchBody; + }, + TContext + >({ + mutationFn: ({ dagId, requestBody }) => + DagRunService.getListDagRunsBatch({ + dagId, + requestBody, + }) as unknown as Promise, + ...options, + }); /** * Get Task Instances Batch * Get list of task instances. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 660dab68b8c4b..b311deb809576 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -1870,6 +1870,144 @@ export const $DAGRunTypes = { description: "DAG Run Types for responses.", } as const; +export const $DAGRunsBatchBody = { + properties: { + order_by: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Order By", + }, + page_offset: { + type: "integer", + minimum: 0, + title: "Page Offset", + default: 0, + }, + page_limit: { + type: "integer", + minimum: 0, + title: "Page Limit", + default: 100, + }, + dag_ids: { + anyOf: [ + { + items: { + type: "string", + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "Dag Ids", + }, + states: { + anyOf: [ + { + items: { + anyOf: [ + { + $ref: "#/components/schemas/DagRunState", + }, + { + type: "null", + }, + ], + }, + type: "array", + }, + { + type: "null", + }, + ], + title: "States", + }, + logical_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Logical Date Gte", + }, + logical_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Logical Date Lte", + }, + start_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date Gte", + }, + start_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Start Date Lte", + }, + end_date_gte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date Gte", + }, + end_date_lte: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "End Date Lte", + }, + }, + type: "object", + title: "DAGRunsBatchBody", + description: "List DAG Runs body for batch endpoint.", +} as const; + export const $DAGSourceResponse = { properties: { content: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index cbb74b4395ca9..a058c537c285e 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -70,6 +70,8 @@ import type { ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, + GetListDagRunsBatchData, + GetListDagRunsBatchResponse, GetDagSourceData, GetDagSourceResponse, GetDagStatsData, @@ -1182,6 +1184,35 @@ export class DagRunService { }, }); } + + /** + * Get List Dag Runs Batch + * Get a list of DAG Runs. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns DAGRunCollectionResponse Successful Response + * @throws ApiError + */ + public static getListDagRunsBatch( + data: GetListDagRunsBatchData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/dags/{dag_id}/dagRuns/list", + path: { + dag_id: data.dagId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 422: "Validation Error", + }, + }); + } } export class DagSourceService { diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index a434871181ea4..e77c2dbce52ca 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -423,6 +423,23 @@ export type DAGRunTypes = { asset_triggered: number; }; +/** + * List DAG Runs body for batch endpoint. + */ +export type DAGRunsBatchBody = { + order_by?: string | null; + page_offset?: number; + page_limit?: number; + dag_ids?: Array | null; + states?: Array | null; + logical_date_gte?: string | null; + logical_date_lte?: string | null; + start_date_gte?: string | null; + start_date_lte?: string | null; + end_date_gte?: string | null; + end_date_lte?: string | null; +}; + /** * DAG Source serializer for responses. */ @@ -1444,6 +1461,13 @@ export type GetDagRunsData = { export type GetDagRunsResponse = DAGRunCollectionResponse; +export type GetListDagRunsBatchData = { + dagId: "~"; + requestBody: DAGRunsBatchBody; +}; + +export type GetListDagRunsBatchResponse = DAGRunCollectionResponse; + export type GetDagSourceData = { accept?: "application/json" | "text/plain" | "*/*"; dagId: string; @@ -2752,6 +2776,33 @@ export type $OpenApiTs = { }; }; }; + "/public/dags/{dag_id}/dagRuns/list": { + post: { + req: GetListDagRunsBatchData; + res: { + /** + * Successful Response + */ + 200: DAGRunCollectionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; + }; "/public/dagSources/{dag_id}": { get: { req: GetDagSourceData; diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 8d0be6fbeafcf..48396be497698 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -30,7 +30,11 @@ from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunTriggeredByType, DagRunType -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags +from tests_common.test_utils.db import ( + clear_db_dags, + clear_db_runs, + clear_db_serialized_dags, +) pytestmark = pytest.mark.db_test @@ -60,6 +64,8 @@ LOGICAL_DATE4 = datetime(2024, 5, 25, 0, 0, tzinfo=timezone.utc) DAG1_RUN1_NOTE = "test_note" +DAG_RUNS_LIST = [DAG1_RUN1_ID, DAG1_RUN2_ID, DAG2_RUN1_ID, DAG2_RUN2_ID] + @pytest.fixture(autouse=True) @provide_session @@ -137,9 +143,30 @@ class TestGetDagRun: DAG1_RUN1_TRIGGERED_BY, DAG1_RUN1_NOTE, ), - (DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, DAG1_RUN2_TRIGGERED_BY, None), - (DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, DAG2_RUN1_TRIGGERED_BY, None), - (DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, DAG2_RUN2_TRIGGERED_BY, None), + ( + DAG1_ID, + DAG1_RUN2_ID, + DAG1_RUN2_STATE, + DAG1_RUN2_RUN_TYPE, + DAG1_RUN2_TRIGGERED_BY, + None, + ), + ( + DAG2_ID, + DAG2_RUN1_ID, + DAG2_RUN1_STATE, + DAG2_RUN1_RUN_TYPE, + DAG2_RUN1_TRIGGERED_BY, + None, + ), + ( + DAG2_ID, + DAG2_RUN2_ID, + DAG2_RUN2_STATE, + DAG2_RUN2_RUN_TYPE, + DAG2_RUN2_TRIGGERED_BY, + None, + ), ], ) def test_get_dag_run(self, test_client, dag_id, run_id, state, run_type, triggered_by, dag_run_note): @@ -312,7 +339,11 @@ def test_bad_limit_and_offset(self, test_client, query_params, expected_detail): @pytest.mark.parametrize( "dag_id, query_params, expected_dag_id_list", [ - (DAG1_ID, {"logical_date_gte": LOGICAL_DATE1.isoformat()}, [DAG1_RUN1_ID, DAG1_RUN2_ID]), + ( + DAG1_ID, + {"logical_date_gte": LOGICAL_DATE1.isoformat()}, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), (DAG2_ID, {"logical_date_lte": LOGICAL_DATE3.isoformat()}, [DAG2_RUN1_ID]), ( "~", @@ -350,12 +381,18 @@ def test_bad_limit_and_offset(self, test_client, query_params, expected_detail): (DAG2_ID, {"state": DagRunState.FAILED.value}, []), ( DAG1_ID, - {"state": DagRunState.SUCCESS.value, "logical_date_gte": LOGICAL_DATE1.isoformat()}, + { + "state": DagRunState.SUCCESS.value, + "logical_date_gte": LOGICAL_DATE1.isoformat(), + }, [DAG1_RUN1_ID], ), ( DAG1_ID, - {"state": DagRunState.FAILED.value, "start_date_gte": START_DATE1.isoformat()}, + { + "state": DagRunState.FAILED.value, + "start_date_gte": START_DATE1.isoformat(), + }, [DAG1_RUN2_ID], ), ], @@ -432,6 +469,343 @@ def test_invalid_state(self, test_client): ) +class TestListDagRunsBatch: + @staticmethod + def parse_datetime(datetime_str): + return datetime_str.isoformat().replace("+00:00", "Z") if datetime_str else None + + @staticmethod + def get_dag_run_dict(run: DagRun): + return { + "dag_run_id": run.run_id, + "dag_id": run.dag_id, + "logical_date": TestGetDagRuns.parse_datetime(run.logical_date), + "queued_at": TestGetDagRuns.parse_datetime(run.queued_at), + "start_date": TestGetDagRuns.parse_datetime(run.start_date), + "end_date": TestGetDagRuns.parse_datetime(run.end_date), + "data_interval_start": TestGetDagRuns.parse_datetime(run.data_interval_start), + "data_interval_end": TestGetDagRuns.parse_datetime(run.data_interval_end), + "last_scheduling_decision": TestGetDagRuns.parse_datetime(run.last_scheduling_decision), + "run_type": run.run_type, + "state": run.state, + "external_trigger": run.external_trigger, + "triggered_by": run.triggered_by.value, + "conf": run.conf, + "note": run.note, + } + + def test_list_dag_runs_return_200(self, test_client, session): + response = test_client.post("/public/dags/~/dagRuns/list", json={}) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 4 + for each in body["dag_runs"]: + run = session.query(DagRun).where(DagRun.run_id == each["dag_run_id"]).one() + expected = self.get_dag_run_dict(run) + assert each == expected + + def test_list_dag_runs_with_invalid_dag_id(self, test_client): + response = test_client.post("/public/dags/invalid/dagRuns/list", json={}) + assert response.status_code == 422 + body = response.json() + assert body["detail"] == [ + { + "type": "literal_error", + "loc": ["path", "dag_id"], + "msg": "Input should be '~'", + "input": "invalid", + "ctx": {"expected": "'~'"}, + } + ] + + @pytest.mark.parametrize( + "dag_ids, status_code, expected_dag_id_list", + [ + ([], 200, DAG_RUNS_LIST), + ([DAG1_ID], 200, [DAG1_RUN1_ID, DAG1_RUN2_ID]), + [["invalid"], 200, []], + ], + ) + def test_list_dag_runs_with_dag_ids_filter(self, test_client, dag_ids, status_code, expected_dag_id_list): + response = test_client.post("/public/dags/~/dagRuns/list", json={"dag_ids": dag_ids}) + assert response.status_code == status_code + assert set([each["dag_run_id"] for each in response.json()["dag_runs"]]) == set(expected_dag_id_list) + + def test_invalid_order_by_raises_400(self, test_client): + response = test_client.post("/public/dags/~/dagRuns/list", json={"order_by": "invalid"}) + assert response.status_code == 400 + body = response.json() + assert ( + body["detail"] + == "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" + ) + + @pytest.mark.parametrize( + "order_by,expected_order", + [ + pytest.param("id", DAG_RUNS_LIST, id="order_by_id"), + pytest.param( + "state", [DAG1_RUN2_ID, DAG1_RUN1_ID, DAG2_RUN1_ID, DAG2_RUN2_ID], id="order_by_state" + ), + pytest.param("dag_id", DAG_RUNS_LIST, id="order_by_dag_id"), + pytest.param("logical_date", DAG_RUNS_LIST, id="order_by_logical_date"), + pytest.param("dag_run_id", DAG_RUNS_LIST, id="order_by_dag_run_id"), + pytest.param("start_date", DAG_RUNS_LIST, id="order_by_start_date"), + pytest.param("end_date", DAG_RUNS_LIST, id="order_by_end_date"), + pytest.param("updated_at", DAG_RUNS_LIST, id="order_by_updated_at"), + pytest.param("external_trigger", DAG_RUNS_LIST, id="order_by_external_trigger"), + pytest.param("conf", DAG_RUNS_LIST, id="order_by_conf"), + ], + ) + def test_dag_runs_ordering(self, test_client, order_by, expected_order): + # Test ascending order + response = test_client.post("/public/dags/~/dagRuns/list", json={"order_by": order_by}) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 4 + assert [run["dag_run_id"] for run in body["dag_runs"]] == expected_order + + # Test descending order + response = test_client.post("/public/dags/~/dagRuns/list", json={"order_by": f"-{order_by}"}) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 4 + assert [run["dag_run_id"] for run in body["dag_runs"]] == expected_order[::-1] + + @pytest.mark.parametrize( + "post_body, expected_dag_id_order", + [ + ({}, DAG_RUNS_LIST), + ({"page_limit": 1}, DAG_RUNS_LIST[:1]), + ({"page_limit": 3}, DAG_RUNS_LIST[:3]), + ({"page_offset": 1}, DAG_RUNS_LIST[1:]), + ({"page_offset": 5}, []), + ({"page_limit": 1, "page_offset": 1}, DAG_RUNS_LIST[1:2]), + ({"page_limit": 1, "page_offset": 2}, DAG_RUNS_LIST[2:3]), + ], + ) + def test_limit_and_offset(self, test_client, post_body, expected_dag_id_order): + response = test_client.post("/public/dags/~/dagRuns/list", json=post_body) + assert response.status_code == 200 + body = response.json() + assert body["total_entries"] == 4 + assert [each["dag_run_id"] for each in body["dag_runs"]] == expected_dag_id_order + + @pytest.mark.parametrize( + "post_body, expected_detail", + [ + ( + {"page_limit": 1, "page_offset": -1}, + [ + { + "type": "greater_than_equal", + "loc": ["body", "page_offset"], + "msg": "Input should be greater than or equal to 0", + "input": -1, + "ctx": {"ge": 0}, + } + ], + ), + ( + {"page_limit": -1, "offset": 1}, + [ + { + "type": "greater_than_equal", + "loc": ["body", "page_limit"], + "msg": "Input should be greater than or equal to 0", + "input": -1, + "ctx": {"ge": 0}, + } + ], + ), + ( + {"page_limit": -1, "page_offset": -1}, + [ + { + "type": "greater_than_equal", + "loc": ["body", "page_offset"], + "msg": "Input should be greater than or equal to 0", + "input": -1, + "ctx": {"ge": 0}, + }, + { + "type": "greater_than_equal", + "loc": ["body", "page_limit"], + "msg": "Input should be greater than or equal to 0", + "input": -1, + "ctx": {"ge": 0}, + }, + ], + ), + ], + ) + def test_bad_limit_and_offset(self, test_client, post_body, expected_detail): + response = test_client.post("/public/dags/~/dagRuns/list", json=post_body) + assert response.status_code == 422 + assert response.json()["detail"] == expected_detail + + @pytest.mark.parametrize( + "post_body, expected_dag_id_list", + [ + ( + {"logical_date_gte": LOGICAL_DATE1.isoformat()}, + DAG_RUNS_LIST, + ), + ({"logical_date_lte": LOGICAL_DATE3.isoformat()}, DAG_RUNS_LIST[:3]), + ( + { + "start_date_gte": START_DATE1.isoformat(), + "start_date_lte": (START_DATE2 - timedelta(days=1)).isoformat(), + }, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), + ( + { + "end_date_gte": START_DATE2.isoformat(), + "end_date_lte": (datetime.now(tz=timezone.utc) + timedelta(days=1)).isoformat(), + }, + DAG_RUNS_LIST, + ), + ( + { + "logical_date_gte": LOGICAL_DATE1.isoformat(), + "logical_date_lte": LOGICAL_DATE2.isoformat(), + }, + [DAG1_RUN1_ID, DAG1_RUN2_ID], + ), + ( + { + "start_date_gte": START_DATE2.isoformat(), + "end_date_lte": (datetime.now(tz=timezone.utc) + timedelta(days=1)).isoformat(), + }, + [DAG2_RUN1_ID, DAG2_RUN2_ID], + ), + ( + {"states": [DagRunState.SUCCESS.value]}, + [DAG1_RUN1_ID, DAG2_RUN1_ID, DAG2_RUN2_ID], + ), + ({"states": [DagRunState.FAILED.value]}, [DAG1_RUN2_ID]), + ( + { + "states": [DagRunState.SUCCESS.value], + "logical_date_gte": LOGICAL_DATE2.isoformat(), + }, + DAG_RUNS_LIST[2:], + ), + ( + { + "states": [DagRunState.FAILED.value], + "start_date_gte": START_DATE1.isoformat(), + }, + [DAG1_RUN2_ID], + ), + ], + ) + def test_filters(self, test_client, post_body, expected_dag_id_list): + response = test_client.post("/public/dags/~/dagRuns/list", json=post_body) + assert response.status_code == 200 + body = response.json() + assert [each["dag_run_id"] for each in body["dag_runs"]] == expected_dag_id_list + + def test_bad_filters(self, test_client): + post_body = { + "logical_date_gte": "invalid", + "start_date_gte": "invalid", + "end_date_gte": "invalid", + "logical_date_lte": "invalid", + "start_date_lte": "invalid", + "end_date_lte": "invalid", + "dag_ids": "invalid", + } + expected_detail = [ + { + "input": "invalid", + "loc": ["body", "dag_ids"], + "msg": "Input should be a valid list", + "type": "list_type", + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "logical_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "logical_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "start_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "start_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "end_date_gte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + { + "type": "datetime_from_date_parsing", + "loc": ["body", "end_date_lte"], + "msg": "Input should be a valid datetime or date, input is too short", + "input": "invalid", + "ctx": {"error": "input is too short"}, + }, + ] + response = test_client.post("/public/dags/~/dagRuns/list", json=post_body) + assert response.status_code == 422 + body = response.json() + assert body["detail"] == expected_detail + + @pytest.mark.parametrize( + "post_body, expected_response", + [ + ( + {"states": ["invalid"]}, + [ + { + "type": "enum", + "loc": ["body", "states", 0], + "msg": "Input should be 'queued', 'running', 'success' or 'failed'", + "input": "invalid", + "ctx": {"expected": "'queued', 'running', 'success' or 'failed'"}, + } + ], + ), + ( + {"states": "invalid"}, + [ + { + "type": "list_type", + "loc": ["body", "states"], + "msg": "Input should be a valid list", + "input": "invalid", + } + ], + ), + ], + ) + def test_invalid_state(self, test_client, post_body, expected_response): + response = test_client.post("/public/dags/~/dagRuns/list", json=post_body) + assert response.status_code == 422 + assert response.json()["detail"] == expected_response + + class TestPatchDagRun: @pytest.mark.parametrize( "dag_id, run_id, patch_body, response_body", @@ -466,7 +840,12 @@ class TestPatchDagRun: {"note": "new note", "state": DagRunState.FAILED}, {"state": DagRunState.FAILED, "note": "new note"}, ), - (DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": DagRunState.FAILED, "note": None}), + ( + DAG1_ID, + DAG1_RUN2_ID, + {"note": None}, + {"state": DagRunState.FAILED, "note": None}, + ), ], ) def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_body): @@ -481,7 +860,12 @@ def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_b @pytest.mark.parametrize( "query_params, patch_body, response_body, expected_status_code", [ - ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, {"state": "success"}, 200), + ( + {"update_mask": ["state"]}, + {"state": DagRunState.SUCCESS}, + {"state": "success"}, + 200, + ), ( {"update_mask": ["note"]}, {"state": DagRunState.FAILED, "note": "new_note1"}, @@ -507,7 +891,9 @@ def test_patch_dag_run_with_update_mask( self, test_client, query_params, patch_body, response_body, expected_status_code ): response = test_client.patch( - f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", params=query_params, json=patch_body + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", + params=query_params, + json=patch_body, ) response_json = response.json() assert response.status_code == expected_status_code @@ -516,7 +902,8 @@ def test_patch_dag_run_with_update_mask( def test_patch_dag_run_not_found(self, test_client): response = test_client.patch( - f"/public/dags/{DAG1_ID}/dagRuns/invalid", json={"state": DagRunState.SUCCESS} + f"/public/dags/{DAG1_ID}/dagRuns/invalid", + json={"state": DagRunState.SUCCESS}, ) assert response.status_code == 404 body = response.json() @@ -618,7 +1005,8 @@ def test_should_respond_404(self, test_client): class TestClearDagRun: def test_clear_dag_run(self, test_client): response = test_client.post( - f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}/clear", json={"dry_run": False} + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}/clear", + json={"dry_run": False}, ) assert response.status_code == 200 body = response.json()