diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index dadfb3e4f42f6..985efc7fc898d 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -305,6 +305,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)) +@mark_fastapi_migration_done @security.requires_access_dag("POST", DagAccessEntity.RUN) @action_logging @provide_session diff --git a/airflow/api_fastapi/core_api/datamodels/dag_run.py b/airflow/api_fastapi/core_api/datamodels/dag_run.py index 55240d15e55ff..ab8126277873e 100644 --- a/airflow/api_fastapi/core_api/datamodels/dag_run.py +++ b/airflow/api_fastapi/core_api/datamodels/dag_run.py @@ -20,9 +20,11 @@ from datetime import datetime from enum import Enum -from pydantic import AwareDatetime, Field, NonNegativeInt +from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator from airflow.api_fastapi.core_api.base import BaseModel +from airflow.models import DagRun +from airflow.utils import timezone from airflow.utils.state import DagRunState from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -75,6 +77,37 @@ class DAGRunCollectionResponse(BaseModel): total_entries: int +class TriggerDAGRunPostBody(BaseModel): + """Trigger DAG Run Serializer for POST body.""" + + dag_run_id: str | None = None + data_interval_start: AwareDatetime | None = None + data_interval_end: AwareDatetime | None = None + + conf: dict = Field(default_factory=dict) + note: str | None = None + + @model_validator(mode="after") + def check_data_intervals(cls, values): + if (values.data_interval_start is None) != (values.data_interval_end is None): + raise ValueError( + "Either both data_interval_start and data_interval_end must be provided or both must be None" + ) + return values + + @model_validator(mode="after") + def validate_dag_run_id(self): + if not self.dag_run_id: + self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date) + return self + + # Mypy issue https://github.com/python/mypy/issues/1362 + @computed_field # type: ignore[misc] + @property + def logical_date(self) -> datetime: + return timezone.utcnow() + + class DAGRunsBatchBody(BaseModel): """List DAG Runs body for batch endpoint.""" diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index c53f68a8438af..46fd0382eb982 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1828,6 +1828,67 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPValidationError' + post: + tags: + - DagRun + summary: Trigger Dag Run + description: Trigger a DAG. + operationId: trigger_dag_run + parameters: + - name: dag_id + in: path + required: true + schema: + title: Dag Id + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/TriggerDAGRunPostBody' + responses: + '200': + description: Successful Response + content: + application/json: + schema: + $ref: '#/components/schemas/DAGRunResponse' + '401': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Unauthorized + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '404': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Not Found + '409': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Conflict + '422': + description: Validation Error + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' /public/dags/{dag_id}/dagRuns/list: post: tags: @@ -8672,6 +8733,36 @@ components: - microseconds title: TimeDelta description: TimeDelta can be used to interact with datetime.timedelta objects. + TriggerDAGRunPostBody: + properties: + dag_run_id: + anyOf: + - type: string + - type: 'null' + title: Dag Run Id + data_interval_start: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Data Interval Start + data_interval_end: + anyOf: + - type: string + format: date-time + - type: 'null' + title: Data Interval End + conf: + type: object + title: Conf + note: + anyOf: + - type: string + - type: 'null' + title: Note + type: object + title: TriggerDAGRunPostBody + description: Trigger DAG Run Serializer for POST body. TriggerResponse: properties: id: diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 5b3eb0e66e505..1e95f75273c16 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -19,6 +19,7 @@ from typing import Annotated, Literal, cast +import pendulum from fastapi import Depends, HTTPException, Query, Request, status from sqlalchemy import select from sqlalchemy.orm import Session @@ -50,13 +51,19 @@ DAGRunPatchStates, DAGRunResponse, DAGRunsBatchBody, + TriggerDAGRunPostBody, ) from airflow.api_fastapi.core_api.datamodels.task_instances import ( TaskInstanceCollectionResponse, TaskInstanceResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.models import DAG, DagRun +from airflow.exceptions import ParamValidationError +from airflow.models import DAG, DagModel, DagRun +from airflow.models.dag_version import DagVersion +from airflow.timetables.base import DataInterval +from airflow.utils.state import DagRunState +from airflow.utils.types import DagRunTriggeredByType, DagRunType dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns") @@ -303,6 +310,67 @@ def get_dag_runs( ) +@dag_run_router.post( + "", + responses=create_openapi_http_exception_doc( + [ + status.HTTP_400_BAD_REQUEST, + status.HTTP_404_NOT_FOUND, + status.HTTP_409_CONFLICT, + ] + ), +) +def trigger_dag_run( + dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)] +) -> DAGRunResponse: + """Trigger a DAG.""" + dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1)) + if not dm: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found") + + if dm.has_import_errors: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered", + ) + + run_id = body.dag_run_id + logical_date = pendulum.instance(body.logical_date) + + try: + dag: DAG = request.app.state.dag_bag.get_dag(dag_id) + + if body.data_interval_start and body.data_interval_end: + data_interval = DataInterval( + start=pendulum.instance(body.data_interval_start), + end=pendulum.instance(body.data_interval_end), + ) + else: + data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date) + dag_version = DagVersion.get_latest_version(dag.dag_id) + dag_run = dag.create_dagrun( + run_type=DagRunType.MANUAL, + run_id=run_id, + logical_date=logical_date, + data_interval=data_interval, + state=DagRunState.QUEUED, + conf=body.conf, + external_trigger=True, + dag_version=dag_version, + session=session, + triggered_by=DagRunTriggeredByType.REST_API, + ) + dag_run_note = body.note + if dag_run_note: + current_user_id = None # refer to https://github.com/apache/airflow/issues/43534 + dag_run.note = (dag_run_note, current_user_id) + return dag_run + except ValueError as e: + raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e)) + except ParamValidationError as e: + raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e)) + + @dag_run_router.post("/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) def get_list_dag_runs_batch( dag_id: Literal["~"], body: DAGRunsBatchBody, session: Annotated[Session, Depends(get_session)] diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index e4bc5f300d8a5..cb7f5c7a53724 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -1573,6 +1573,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited< export type DagRunServiceClearDagRunMutationResult = Awaited< ReturnType >; +export type DagRunServiceTriggerDagRunMutationResult = Awaited< + ReturnType +>; export type DagRunServiceGetListDagRunsBatchMutationResult = Awaited< ReturnType >; diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index 6ff3e83ccced2..d644beb729488 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -48,6 +48,7 @@ import { PoolPostBody, PoolPostBulkBody, TaskInstancesBatchBody, + TriggerDAGRunPostBody, VariableBody, } from "../requests/types.gen"; import * as Common from "./common"; @@ -2726,6 +2727,49 @@ export const useDagRunServiceClearDagRun = < }) as unknown as Promise, ...options, }); +/** + * Trigger Dag Run + * Trigger a DAG. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns DAGRunResponse Successful Response + * @throws ApiError + */ +export const useDagRunServiceTriggerDagRun = < + TData = Common.DagRunServiceTriggerDagRunMutationResult, + TError = unknown, + TContext = unknown, +>( + options?: Omit< + UseMutationOptions< + TData, + TError, + { + dagId: unknown; + requestBody: TriggerDAGRunPostBody; + }, + TContext + >, + "mutationFn" + >, +) => + useMutation< + TData, + TError, + { + dagId: unknown; + requestBody: TriggerDAGRunPostBody; + }, + TContext + >({ + mutationFn: ({ dagId, requestBody }) => + DagRunService.triggerDagRun({ + dagId, + requestBody, + }) as unknown as Promise, + ...options, + }); /** * Get List Dag Runs Batch * Get a list of DAG Runs. diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 8002b9d37f6c0..f657ba8d4e2d8 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -4852,6 +4852,64 @@ export const $TimeDelta = { "TimeDelta can be used to interact with datetime.timedelta objects.", } as const; +export const $TriggerDAGRunPostBody = { + properties: { + dag_run_id: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Dag Run Id", + }, + data_interval_start: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Data Interval Start", + }, + data_interval_end: { + anyOf: [ + { + type: "string", + format: "date-time", + }, + { + type: "null", + }, + ], + title: "Data Interval End", + }, + conf: { + type: "object", + title: "Conf", + }, + note: { + anyOf: [ + { + type: "string", + }, + { + type: "null", + }, + ], + title: "Note", + }, + }, + type: "object", + title: "TriggerDAGRunPostBody", + description: "Trigger DAG Run Serializer for POST body.", +} as const; + export const $TriggerResponse = { properties: { id: { diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index d8cb33bceec9f..4c8944447cec2 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -70,6 +70,8 @@ import type { ClearDagRunResponse, GetDagRunsData, GetDagRunsResponse, + TriggerDagRunData, + TriggerDagRunResponse, GetListDagRunsBatchData, GetListDagRunsBatchResponse, GetDagSourceData, @@ -1193,6 +1195,37 @@ export class DagRunService { }); } + /** + * Trigger Dag Run + * Trigger a DAG. + * @param data The data for the request. + * @param data.dagId + * @param data.requestBody + * @returns DAGRunResponse Successful Response + * @throws ApiError + */ + public static triggerDagRun( + data: TriggerDagRunData, + ): CancelablePromise { + return __request(OpenAPI, { + method: "POST", + url: "/public/dags/{dag_id}/dagRuns", + path: { + dag_id: data.dagId, + }, + body: data.requestBody, + mediaType: "application/json", + errors: { + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 409: "Conflict", + 422: "Validation Error", + }, + }); + } + /** * Get List Dag Runs Batch * Get a list of DAG Runs. diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index bdcce0157dce2..e15de90441bb0 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -1126,6 +1126,19 @@ export type TimeDelta = { microseconds: number; }; +/** + * Trigger DAG Run Serializer for POST body. + */ +export type TriggerDAGRunPostBody = { + dag_run_id?: string | null; + data_interval_start?: string | null; + data_interval_end?: string | null; + conf?: { + [key: string]: unknown; + }; + note?: string | null; +}; + /** * Trigger serializer for responses. */ @@ -1494,6 +1507,13 @@ export type GetDagRunsData = { export type GetDagRunsResponse = DAGRunCollectionResponse; +export type TriggerDagRunData = { + dagId: unknown; + requestBody: TriggerDAGRunPostBody; +}; + +export type TriggerDagRunResponse = DAGRunResponse; + export type GetListDagRunsBatchData = { dagId: "~"; requestBody: DAGRunsBatchBody; @@ -2853,6 +2873,39 @@ export type $OpenApiTs = { 422: HTTPValidationError; }; }; + post: { + req: TriggerDagRunData; + res: { + /** + * Successful Response + */ + 200: DAGRunResponse; + /** + * Bad Request + */ + 400: HTTPExceptionResponse; + /** + * Unauthorized + */ + 401: HTTPExceptionResponse; + /** + * Forbidden + */ + 403: HTTPExceptionResponse; + /** + * Not Found + */ + 404: HTTPExceptionResponse; + /** + * Conflict + */ + 409: HTTPExceptionResponse; + /** + * Validation Error + */ + 422: HTTPValidationError; + }; + }; }; "/public/dags/{dag_id}/dagRuns/list": { post: { diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 6e9f2b69eb291..d453c973c8d0f 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -17,15 +17,19 @@ from __future__ import annotations -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta +from unittest import mock import pytest +import time_machine from sqlalchemy import select -from airflow.models import DagRun +from airflow.models import DagModel, DagRun from airflow.models.asset import AssetEvent, AssetModel +from airflow.models.param import Param from airflow.operators.empty import EmptyOperator from airflow.sdk.definitions.asset import Asset +from airflow.utils import timezone from airflow.utils.session import provide_session from airflow.utils.state import DagRunState, State from airflow.utils.types import DagRunTriggeredByType, DagRunType @@ -63,20 +67,24 @@ LOGICAL_DATE3 = datetime(2024, 5, 16, 0, 0, tzinfo=timezone.utc) LOGICAL_DATE4 = datetime(2024, 5, 25, 0, 0, tzinfo=timezone.utc) DAG1_RUN1_NOTE = "test_note" +DAG2_PARAM = {"validated_number": Param(1, minimum=1, maximum=10)} DAG_RUNS_LIST = [DAG1_RUN1_ID, DAG1_RUN2_ID, DAG2_RUN1_ID, DAG2_RUN2_ID] @pytest.fixture(autouse=True) @provide_session -def setup(dag_maker, session=None): +def setup(request, dag_maker, session=None): clear_db_runs() clear_db_dags() clear_db_serialized_dags() + if "no_setup" in request.keywords: + return + with dag_maker( DAG1_ID, - schedule="@daily", + schedule=None, start_date=START_DATE1, ): task1 = EmptyOperator(task_id="task_1") @@ -102,11 +110,7 @@ def setup(dag_maker, session=None): logical_date=LOGICAL_DATE2, ) - with dag_maker( - DAG2_ID, - schedule=None, - start_date=START_DATE2, - ): + with dag_maker(DAG2_ID, schedule=None, start_date=START_DATE2, params=DAG2_PARAM): EmptyOperator(task_id="task_2") dag_maker.create_dagrun( run_id=DAG2_RUN1_ID, @@ -1048,3 +1052,308 @@ def test_clear_dag_run_unprocessable_entity(self, test_client): body = response.json() assert body["detail"][0]["msg"] == "Field required" assert body["detail"][0]["loc"][0] == "body" + + +class TestTriggerDagRun: + def _dags_for_trigger_tests(self, session=None): + inactive_dag = DagModel( + dag_id="inactive", + fileloc="/tmp/dag_del_1.py", + timetable_summary="2 2 * * *", + is_active=False, + is_paused=True, + owners="test_owner,another_test_owner", + next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + ) + + import_errors_dag = DagModel( + dag_id="import_errors", + fileloc="/tmp/dag_del_2.py", + timetable_summary="2 2 * * *", + is_active=True, + owners="test_owner,another_test_owner", + next_dagrun=datetime(2021, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + ) + import_errors_dag.has_import_errors = True + + session.add(inactive_dag) + session.add(import_errors_dag) + session.commit() + + @time_machine.travel(timezone.utcnow(), tick=False) + @pytest.mark.parametrize( + "dag_run_id, note, data_interval_start, data_interval_end", + [ + ("dag_run_5", "test-note", None, None), + ( + "dag_run_6", + "test-note", + "2024-01-03T00:00:00+00:00", + "2024-01-04T05:00:00+00:00", + ), + (None, None, None, None), + ], + ) + def test_should_respond_200( + self, + test_client, + dag_run_id, + note, + data_interval_start, + data_interval_end, + ): + fixed_now = timezone.utcnow().isoformat() + + request_json = {"note": note} + if dag_run_id is not None: + request_json["dag_run_id"] = dag_run_id + if data_interval_start is not None: + request_json["data_interval_start"] = data_interval_start + if data_interval_end is not None: + request_json["data_interval_end"] = data_interval_end + + response = test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json=request_json, + ) + assert response.status_code == 200 + + if dag_run_id is None: + expected_dag_run_id = f"manual__{fixed_now}" + else: + expected_dag_run_id = dag_run_id + + expected_data_interval_start = fixed_now.replace("+00:00", "Z") + expected_data_interval_end = fixed_now.replace("+00:00", "Z") + if data_interval_start is not None and data_interval_end is not None: + expected_data_interval_start = data_interval_start.replace("+00:00", "Z") + expected_data_interval_end = data_interval_end.replace("+00:00", "Z") + + expected_response_json = { + "conf": {}, + "dag_id": DAG1_ID, + "dag_run_id": expected_dag_run_id, + "end_date": None, + "logical_date": fixed_now.replace("+00:00", "Z"), + "external_trigger": True, + "start_date": None, + "state": "queued", + "data_interval_end": expected_data_interval_end, + "data_interval_start": expected_data_interval_start, + "queued_at": fixed_now.replace("+00:00", "Z"), + "last_scheduling_decision": None, + "run_type": "manual", + "note": note, + "triggered_by": "rest_api", + } + + assert response.json() == expected_response_json + + @pytest.mark.parametrize( + "post_body, expected_detail", + [ + # Uncomment these 2 test cases once https://github.com/apache/airflow/pull/44306 is merged + # ( + # {"executiondate": "2020-11-10T08:25:56Z"}, + # { + # "detail": [ + # { + # "input": "2020-11-10T08:25:56Z", + # "loc": ["body", "executiondate"], + # "msg": "Extra inputs are not permitted", + # "type": "extra_forbidden", + # } + # ] + # }, + # ), + # ( + # {"logical_date": "2020-11-10T08:25:56"}, + # { + # "detail": [ + # { + # "input": "2020-11-10T08:25:56", + # "loc": ["body", "logical_date"], + # "msg": "Extra inputs are not permitted", + # "type": "extra_forbidden", + # } + # ] + # }, + # ), + ( + {"data_interval_start": "2020-11-10T08:25:56"}, + { + "detail": [ + { + "input": "2020-11-10T08:25:56", + "loc": ["body", "data_interval_start"], + "msg": "Input should have timezone info", + "type": "timezone_aware", + } + ] + }, + ), + ( + {"data_interval_end": "2020-11-10T08:25:56"}, + { + "detail": [ + { + "input": "2020-11-10T08:25:56", + "loc": ["body", "data_interval_end"], + "msg": "Input should have timezone info", + "type": "timezone_aware", + } + ] + }, + ), + ( + {"dag_run_id": 20}, + { + "detail": [ + { + "input": 20, + "loc": ["body", "dag_run_id"], + "msg": "Input should be a valid string", + "type": "string_type", + } + ] + }, + ), + ( + {"note": 20}, + { + "detail": [ + { + "input": 20, + "loc": ["body", "note"], + "msg": "Input should be a valid string", + "type": "string_type", + } + ] + }, + ), + ( + {"conf": 20}, + { + "detail": [ + { + "input": 20, + "loc": ["body", "conf"], + "msg": "Input should be a valid dictionary", + "type": "dict_type", + } + ] + }, + ), + ], + ) + def test_invalid_data(self, test_client, post_body, expected_detail): + response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json=post_body) + assert response.status_code == 422 + assert response.json() == expected_detail + + @mock.patch("airflow.models.DAG.create_dagrun") + def test_dagrun_creation_exception_is_handled(self, mock_create_dagrun, test_client): + error_message = "Encountered Error" + + mock_create_dagrun.side_effect = ValueError(error_message) + + response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={}) + assert response.status_code == 400 + assert response.json() == {"detail": error_message} + + def test_should_respond_404_if_a_dag_is_inactive(self, test_client, session): + self._dags_for_trigger_tests(session) + response = test_client.post("/public/dags/inactive/dagRuns", json={}) + assert response.status_code == 404 + assert response.json()["detail"] == "DAG with dag_id: 'inactive' not found" + + def test_should_respond_400_if_a_dag_has_import_errors(self, test_client, session): + self._dags_for_trigger_tests(session) + response = test_client.post("/public/dags/import_errors/dagRuns", json={}) + assert response.status_code == 400 + assert ( + response.json()["detail"] + == "DAG with dag_id: 'import_errors' has import errors and cannot be triggered" + ) + + @time_machine.travel(timezone.utcnow(), tick=False) + def test_should_response_200_for_duplicate_logical_date(self, test_client): + RUN_ID_1 = "random_1" + RUN_ID_2 = "random_2" + now = timezone.utcnow().isoformat().replace("+00:00", "Z") + note = "duplicate logical date test" + response_1 = test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json={"dag_run_id": RUN_ID_1, "note": note}, + ) + response_2 = test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json={"dag_run_id": RUN_ID_2, "note": note}, + ) + + assert response_1.status_code == response_2.status_code == 200 + body1 = response_1.json() + body2 = response_2.json() + + for each_run_id, each_body in [(RUN_ID_1, body1), (RUN_ID_2, body2)]: + assert each_body == { + "dag_run_id": each_run_id, + "dag_id": DAG1_ID, + "logical_date": now, + "queued_at": now, + "start_date": None, + "end_date": None, + "data_interval_start": now, + "data_interval_end": now, + "last_scheduling_decision": None, + "run_type": "manual", + "state": "queued", + "external_trigger": True, + "triggered_by": "rest_api", + "conf": {}, + "note": note, + } + + @pytest.mark.parametrize( + "data_interval_start, data_interval_end", + [ + ( + LOGICAL_DATE1.isoformat(), + None, + ), + ( + None, + LOGICAL_DATE1.isoformat(), + ), + ], + ) + def test_should_response_422_for_missing_start_date_or_end_date( + self, test_client, data_interval_start, data_interval_end + ): + response = test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json={"data_interval_start": data_interval_start, "data_interval_end": data_interval_end}, + ) + assert response.status_code == 422 + assert ( + response.json()["detail"][0]["msg"] + == "Value error, Either both data_interval_start and data_interval_end must be provided or both must be None" + ) + + def test_raises_validation_error_for_invalid_params(self, test_client): + response = test_client.post( + f"/public/dags/{DAG2_ID}/dagRuns", + json={"conf": {"validated_number": 5000}}, + ) + assert response.status_code == 400 + assert "Invalid input for param validated_number" in response.json()["detail"] + + def test_response_404(self, test_client): + response = test_client.post("/public/dags/randoms/dagRuns", json={}) + assert response.status_code == 404 + assert response.json()["detail"] == "DAG with dag_id: 'randoms' not found" + + def test_response_409(self, test_client): + response = test_client.post(f"/public/dags/{DAG1_ID}/dagRuns", json={"dag_run_id": DAG1_RUN1_ID}) + assert response.status_code == 409 + assert response.json()["detail"] == "Unique constraint violation"