Skip to content

Commit

Permalink
AIP-84 Migrate Trigger Dag Run endpoint to FastAPI (apache#43875)
Browse files Browse the repository at this point in the history
* init

* wip

* remove logical_date

* fix trigger dag_run

* tests WIP

* working tests

* remove logical_date from post body

* remove logical_date from tests

* fix

* include return type

* fix conf

* feedback

* fix tests

* Update tests/api_fastapi/core_api/routes/public/test_dag_run.py

* feedback
  • Loading branch information
rawwar authored Nov 27, 2024
1 parent 9ee501d commit a1fbdb3
Show file tree
Hide file tree
Showing 10 changed files with 704 additions and 11 deletions.
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@mark_fastapi_migration_done
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@action_logging
@provide_session
Expand Down
35 changes: 34 additions & 1 deletion airflow/api_fastapi/core_api/datamodels/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from datetime import datetime
from enum import Enum

from pydantic import AwareDatetime, Field, NonNegativeInt
from pydantic import AwareDatetime, Field, NonNegativeInt, computed_field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.models import DagRun
from airflow.utils import timezone
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

Expand Down Expand Up @@ -75,6 +77,37 @@ class DAGRunCollectionResponse(BaseModel):
total_entries: int


class TriggerDAGRunPostBody(BaseModel):
"""Trigger DAG Run Serializer for POST body."""

dag_run_id: str | None = None
data_interval_start: AwareDatetime | None = None
data_interval_end: AwareDatetime | None = None

conf: dict = Field(default_factory=dict)
note: str | None = None

@model_validator(mode="after")
def check_data_intervals(cls, values):
if (values.data_interval_start is None) != (values.data_interval_end is None):
raise ValueError(
"Either both data_interval_start and data_interval_end must be provided or both must be None"
)
return values

@model_validator(mode="after")
def validate_dag_run_id(self):
if not self.dag_run_id:
self.dag_run_id = DagRun.generate_run_id(DagRunType.MANUAL, self.logical_date)
return self

# Mypy issue https://github.com/python/mypy/issues/1362
@computed_field # type: ignore[misc]
@property
def logical_date(self) -> datetime:
return timezone.utcnow()


class DAGRunsBatchBody(BaseModel):
"""List DAG Runs body for batch endpoint."""

Expand Down
91 changes: 91 additions & 0 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1828,6 +1828,67 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
post:
tags:
- DagRun
summary: Trigger Dag Run
description: Trigger a DAG.
operationId: trigger_dag_run
parameters:
- name: dag_id
in: path
required: true
schema:
title: Dag Id
requestBody:
required: true
content:
application/json:
schema:
$ref: '#/components/schemas/TriggerDAGRunPostBody'
responses:
'200':
description: Successful Response
content:
application/json:
schema:
$ref: '#/components/schemas/DAGRunResponse'
'401':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Unauthorized
'403':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Forbidden
'400':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Bad Request
'404':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Not Found
'409':
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPExceptionResponse'
description: Conflict
'422':
description: Validation Error
content:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/list:
post:
tags:
Expand Down Expand Up @@ -8672,6 +8733,36 @@ components:
- microseconds
title: TimeDelta
description: TimeDelta can be used to interact with datetime.timedelta objects.
TriggerDAGRunPostBody:
properties:
dag_run_id:
anyOf:
- type: string
- type: 'null'
title: Dag Run Id
data_interval_start:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval Start
data_interval_end:
anyOf:
- type: string
format: date-time
- type: 'null'
title: Data Interval End
conf:
type: object
title: Conf
note:
anyOf:
- type: string
- type: 'null'
title: Note
type: object
title: TriggerDAGRunPostBody
description: Trigger DAG Run Serializer for POST body.
TriggerResponse:
properties:
id:
Expand Down
70 changes: 69 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from typing import Annotated, Literal, cast

import pendulum
from fastapi import Depends, HTTPException, Query, Request, status
from sqlalchemy import select
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -50,13 +51,19 @@
DAGRunPatchStates,
DAGRunResponse,
DAGRunsBatchBody,
TriggerDAGRunPostBody,
)
from airflow.api_fastapi.core_api.datamodels.task_instances import (
TaskInstanceCollectionResponse,
TaskInstanceResponse,
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.models import DAG, DagRun
from airflow.exceptions import ParamValidationError
from airflow.models import DAG, DagModel, DagRun
from airflow.models.dag_version import DagVersion
from airflow.timetables.base import DataInterval
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType

dag_run_router = AirflowRouter(tags=["DagRun"], prefix="/dags/{dag_id}/dagRuns")

Expand Down Expand Up @@ -303,6 +310,67 @@ def get_dag_runs(
)


@dag_run_router.post(
"",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_404_NOT_FOUND,
status.HTTP_409_CONFLICT,
]
),
)
def trigger_dag_run(
dag_id, body: TriggerDAGRunPostBody, request: Request, session: Annotated[Session, Depends(get_session)]
) -> DAGRunResponse:
"""Trigger a DAG."""
dm = session.scalar(select(DagModel).where(DagModel.is_active, DagModel.dag_id == dag_id).limit(1))
if not dm:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG with dag_id: '{dag_id}' not found")

if dm.has_import_errors:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"DAG with dag_id: '{dag_id}' has import errors and cannot be triggered",
)

run_id = body.dag_run_id
logical_date = pendulum.instance(body.logical_date)

try:
dag: DAG = request.app.state.dag_bag.get_dag(dag_id)

if body.data_interval_start and body.data_interval_end:
data_interval = DataInterval(
start=pendulum.instance(body.data_interval_start),
end=pendulum.instance(body.data_interval_end),
)
else:
data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
dag_version = DagVersion.get_latest_version(dag.dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
logical_date=logical_date,
data_interval=data_interval,
state=DagRunState.QUEUED,
conf=body.conf,
external_trigger=True,
dag_version=dag_version,
session=session,
triggered_by=DagRunTriggeredByType.REST_API,
)
dag_run_note = body.note
if dag_run_note:
current_user_id = None # refer to https://github.com/apache/airflow/issues/43534
dag_run.note = (dag_run_note, current_user_id)
return dag_run
except ValueError as e:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
except ParamValidationError as e:
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))


@dag_run_router.post("/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]))
def get_list_dag_runs_batch(
dag_id: Literal["~"], body: DAGRunsBatchBody, session: Annotated[Session, Depends(get_session)]
Expand Down
3 changes: 3 additions & 0 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1573,6 +1573,9 @@ export type ConnectionServiceTestConnectionMutationResult = Awaited<
export type DagRunServiceClearDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.clearDagRun>
>;
export type DagRunServiceTriggerDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.triggerDagRun>
>;
export type DagRunServiceGetListDagRunsBatchMutationResult = Awaited<
ReturnType<typeof DagRunService.getListDagRunsBatch>
>;
Expand Down
44 changes: 44 additions & 0 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import {
PoolPostBody,
PoolPostBulkBody,
TaskInstancesBatchBody,
TriggerDAGRunPostBody,
VariableBody,
} from "../requests/types.gen";
import * as Common from "./common";
Expand Down Expand Up @@ -2726,6 +2727,49 @@ export const useDagRunServiceClearDagRun = <
}) as unknown as Promise<TData>,
...options,
});
/**
* Trigger Dag Run
* Trigger a DAG.
* @param data The data for the request.
* @param data.dagId
* @param data.requestBody
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
export const useDagRunServiceTriggerDagRun = <
TData = Common.DagRunServiceTriggerDagRunMutationResult,
TError = unknown,
TContext = unknown,
>(
options?: Omit<
UseMutationOptions<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>,
"mutationFn"
>,
) =>
useMutation<
TData,
TError,
{
dagId: unknown;
requestBody: TriggerDAGRunPostBody;
},
TContext
>({
mutationFn: ({ dagId, requestBody }) =>
DagRunService.triggerDagRun({
dagId,
requestBody,
}) as unknown as Promise<TData>,
...options,
});
/**
* Get List Dag Runs Batch
* Get a list of DAG Runs.
Expand Down
58 changes: 58 additions & 0 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4852,6 +4852,64 @@ export const $TimeDelta = {
"TimeDelta can be used to interact with datetime.timedelta objects.",
} as const;

export const $TriggerDAGRunPostBody = {
properties: {
dag_run_id: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Dag Run Id",
},
data_interval_start: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval Start",
},
data_interval_end: {
anyOf: [
{
type: "string",
format: "date-time",
},
{
type: "null",
},
],
title: "Data Interval End",
},
conf: {
type: "object",
title: "Conf",
},
note: {
anyOf: [
{
type: "string",
},
{
type: "null",
},
],
title: "Note",
},
},
type: "object",
title: "TriggerDAGRunPostBody",
description: "Trigger DAG Run Serializer for POST body.",
} as const;

export const $TriggerResponse = {
properties: {
id: {
Expand Down
Loading

0 comments on commit a1fbdb3

Please sign in to comment.