Skip to content

Commit

Permalink
AIP-84 Add ability to update dag run note in PATCH dag_run endpoint (a…
Browse files Browse the repository at this point in the history
…pache#43508)

* include dag_run_note in update_mask

* add dag run note

* state can be none

* add test

* Fix tests

* handle edge cases

* add tests

* remove joinedload

* fix update_mask checks

* fix tests

* fix

* remove async

* undo async

* fix

---------

Co-authored-by: pierrejeambrun <[email protected]>
  • Loading branch information
rawwar and pierrejeambrun authored Nov 8, 2024
1 parent cd323e2 commit f57db71
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 54 deletions.
16 changes: 11 additions & 5 deletions airflow/api_fastapi/core_api/openapi/v1-generated.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1318,9 +1318,9 @@ paths:
patch:
tags:
- DagRun
summary: Patch Dag Run State
summary: Patch Dag Run
description: Modify a DAG Run.
operationId: patch_dag_run_state
operationId: patch_dag_run
parameters:
- name: dag_id
in: path
Expand Down Expand Up @@ -3694,10 +3694,16 @@ components:
DAGRunPatchBody:
properties:
state:
$ref: '#/components/schemas/DAGRunPatchStates'
anyOf:
- $ref: '#/components/schemas/DAGRunPatchStates'
- type: 'null'
note:
anyOf:
- type: string
maxLength: 1000
- type: 'null'
title: Note
type: object
required:
- state
title: DAGRunPatchBody
description: DAG Run Serializer for PATCH requests.
DAGRunPatchStates:
Expand Down
34 changes: 20 additions & 14 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Session, Dep
]
),
)
def patch_dag_run_state(
def patch_dag_run(
dag_id: str,
dag_run_id: str,
patch_body: DAGRunPatchBody,
Expand All @@ -121,23 +121,29 @@ def patch_dag_run_state(
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found")

if update_mask:
if update_mask != ["state"]:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Only `state` field can be updated through the REST API"
)
data = patch_body.model_dump(include=set(update_mask))
else:
update_mask = ["state"]
data = patch_body.model_dump()

for attr_name in update_mask:
for attr_name, attr_value in data.items():
if attr_name == "state":
state = getattr(patch_body, attr_name)
if state == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True)
attr_value = getattr(patch_body, "state")
if attr_value == DAGRunPatchStates.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.QUEUED:
set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_value == DAGRunPatchStates.FAILED:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session)
elif attr_name == "note":
# Once Authentication is implemented in this FastAPI app,
# user id will be added when updating dag run note
# Refer to https://github.com/apache/airflow/issues/43534
dag_run = session.get(DagRun, dag_run.id)
if dag_run.dag_run_note is None:
dag_run.note = (attr_value, None)
else:
set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True)
dag_run.dag_run_note.content = attr_value

session.refresh(dag_run)
dag_run = session.get(DagRun, dag_run.id)

return DAGRunResponse.model_validate(dag_run, from_attributes=True)
3 changes: 2 additions & 1 deletion airflow/api_fastapi/core_api/serializers/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class DAGRunPatchStates(str, Enum):
class DAGRunPatchBody(BaseModel):
"""DAG Run Serializer for PATCH requests."""

state: DAGRunPatchStates
state: DAGRunPatchStates | None = None
note: str | None = Field(None, max_length=1000)


class DAGRunResponse(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions airflow/ui/openapi-gen/queries/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -935,8 +935,8 @@ export type DagServicePatchDagMutationResult = Awaited<
export type ConnectionServicePatchConnectionMutationResult = Awaited<
ReturnType<typeof ConnectionService.patchConnection>
>;
export type DagRunServicePatchDagRunStateMutationResult = Awaited<
ReturnType<typeof DagRunService.patchDagRunState>
export type DagRunServicePatchDagRunMutationResult = Awaited<
ReturnType<typeof DagRunService.patchDagRun>
>;
export type PoolServicePatchPoolMutationResult = Awaited<
ReturnType<typeof PoolService.patchPool>
Expand Down
8 changes: 4 additions & 4 deletions airflow/ui/openapi-gen/queries/queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1912,7 +1912,7 @@ export const useConnectionServicePatchConnection = <
...options,
});
/**
* Patch Dag Run State
* Patch Dag Run
* Modify a DAG Run.
* @param data The data for the request.
* @param data.dagId
Expand All @@ -1922,8 +1922,8 @@ export const useConnectionServicePatchConnection = <
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
export const useDagRunServicePatchDagRunState = <
TData = Common.DagRunServicePatchDagRunStateMutationResult,
export const useDagRunServicePatchDagRun = <
TData = Common.DagRunServicePatchDagRunMutationResult,
TError = unknown,
TContext = unknown,
>(
Expand Down Expand Up @@ -1954,7 +1954,7 @@ export const useDagRunServicePatchDagRunState = <
TContext
>({
mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) =>
DagRunService.patchDagRunState({
DagRunService.patchDagRun({
dagId,
dagRunId,
requestBody,
Expand Down
22 changes: 20 additions & 2 deletions airflow/ui/openapi-gen/requests/schemas.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -981,11 +981,29 @@ export const $DAGResponse = {
export const $DAGRunPatchBody = {
properties: {
state: {
$ref: "#/components/schemas/DAGRunPatchStates",
anyOf: [
{
$ref: "#/components/schemas/DAGRunPatchStates",
},
{
type: "null",
},
],
},
note: {
anyOf: [
{
type: "string",
maxLength: 1000,
},
{
type: "null",
},
],
title: "Note",
},
},
type: "object",
required: ["state"],
title: "DAGRunPatchBody",
description: "DAG Run Serializer for PATCH requests.",
} as const;
Expand Down
12 changes: 6 additions & 6 deletions airflow/ui/openapi-gen/requests/services.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ import type {
GetDagRunResponse,
DeleteDagRunData,
DeleteDagRunResponse,
PatchDagRunStateData,
PatchDagRunStateResponse,
PatchDagRunData,
PatchDagRunResponse,
GetDagSourceData,
GetDagSourceResponse,
GetEventLogData,
Expand Down Expand Up @@ -794,7 +794,7 @@ export class DagRunService {
}

/**
* Patch Dag Run State
* Patch Dag Run
* Modify a DAG Run.
* @param data The data for the request.
* @param data.dagId
Expand All @@ -804,9 +804,9 @@ export class DagRunService {
* @returns DAGRunResponse Successful Response
* @throws ApiError
*/
public static patchDagRunState(
data: PatchDagRunStateData,
): CancelablePromise<PatchDagRunStateResponse> {
public static patchDagRun(
data: PatchDagRunData,
): CancelablePromise<PatchDagRunResponse> {
return __request(OpenAPI, {
method: "PATCH",
url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}",
Expand Down
9 changes: 5 additions & 4 deletions airflow/ui/openapi-gen/requests/types.gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ export type DAGResponse = {
* DAG Run Serializer for PATCH requests.
*/
export type DAGRunPatchBody = {
state: DAGRunPatchStates;
state?: DAGRunPatchStates | null;
note?: string | null;
};

/**
Expand Down Expand Up @@ -932,14 +933,14 @@ export type DeleteDagRunData = {

export type DeleteDagRunResponse = void;

export type PatchDagRunStateData = {
export type PatchDagRunData = {
dagId: string;
dagRunId: string;
requestBody: DAGRunPatchBody;
updateMask?: Array<string> | null;
};

export type PatchDagRunStateResponse = DAGRunResponse;
export type PatchDagRunResponse = DAGRunResponse;

export type GetDagSourceData = {
accept?: string;
Expand Down Expand Up @@ -1775,7 +1776,7 @@ export type $OpenApiTs = {
};
};
patch: {
req: PatchDagRunStateData;
req: PatchDagRunData;
res: {
/**
* Successful Response
Expand Down
88 changes: 72 additions & 16 deletions tests/api_fastapi/core_api/routes/public/test_dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API
START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc)
EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc)
DAG1_NOTE = "test_note"
DAG1_RUN1_NOTE = "test_note"


@pytest.fixture(autouse=True)
Expand All @@ -66,13 +66,13 @@ def setup(dag_maker, session=None):
start_date=START_DATE,
):
EmptyOperator(task_id="task_1")
dag1 = dag_maker.create_dagrun(
dag_run1 = dag_maker.create_dagrun(
run_id=DAG1_RUN1_ID,
state=DAG1_RUN1_STATE,
run_type=DAG1_RUN1_RUN_TYPE,
triggered_by=DAG1_RUN1_TRIGGERED_BY,
)
dag1.note = (DAG1_NOTE, 1)
dag_run1.note = (DAG1_RUN1_NOTE, 1)

dag_maker.create_dagrun(
run_id=DAG1_RUN2_ID,
Expand Down Expand Up @@ -114,7 +114,14 @@ class TestGetDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, run_type, triggered_by, dag_run_note",
[
(DAG1_ID, DAG1_RUN1_ID, DAG1_RUN1_STATE, DAG1_RUN1_RUN_TYPE, DAG1_RUN1_TRIGGERED_BY, DAG1_NOTE),
(
DAG1_ID,
DAG1_RUN1_ID,
DAG1_RUN1_STATE,
DAG1_RUN1_RUN_TYPE,
DAG1_RUN1_TRIGGERED_BY,
DAG1_RUN1_NOTE,
),
(DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, DAG1_RUN2_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, DAG2_RUN1_TRIGGERED_BY, None),
(DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, DAG2_RUN2_TRIGGERED_BY, None),
Expand All @@ -140,36 +147,85 @@ def test_get_dag_run_not_found(self, test_client):

class TestPatchDagRun:
@pytest.mark.parametrize(
"dag_id, run_id, state, response_state",
"dag_id, run_id, patch_body, response_body",
[
(DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
(DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
(DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
(
DAG1_ID,
DAG1_RUN1_ID,
{"state": DagRunState.FAILED, "note": "new_note2"},
{"state": DagRunState.FAILED, "note": "new_note2"},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"state": DagRunState.SUCCESS},
{"state": DagRunState.SUCCESS, "note": None},
),
(
DAG2_ID,
DAG2_RUN1_ID,
{"state": DagRunState.QUEUED},
{"state": DagRunState.QUEUED, "note": None},
),
(
DAG1_ID,
DAG1_RUN1_ID,
{"note": "updated note"},
{"state": DagRunState.SUCCESS, "note": "updated note"},
),
(
DAG1_ID,
DAG1_RUN2_ID,
{"note": "new note", "state": DagRunState.FAILED},
{"state": DagRunState.FAILED, "note": "new note"},
),
(DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": DagRunState.FAILED, "note": None}),
],
)
def test_patch_dag_run(self, test_client, dag_id, run_id, state, response_state):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": state})
def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_body):
response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json=patch_body)
assert response.status_code == 200
body = response.json()
assert body["dag_id"] == dag_id
assert body["run_id"] == run_id
assert body["state"] == response_state
assert body.get("state") == response_body.get("state")
assert body.get("note") == response_body.get("note")

@pytest.mark.parametrize(
"query_params, patch_body, expected_status_code",
"query_params, patch_body, response_body, expected_status_code",
[
({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200),
({}, {"state": DagRunState.SUCCESS}, 200),
({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400),
({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, {"state": "success"}, 200),
(
{"update_mask": ["note"]},
{"state": DagRunState.FAILED, "note": "new_note1"},
{"note": "new_note1", "state": "success"},
200,
),
(
{},
{"state": DagRunState.FAILED, "note": "new_note2"},
{"note": "new_note2", "state": "failed"},
200,
),
({"update_mask": ["note"]}, {}, {"state": "success", "note": None}, 200),
(
{"update_mask": ["random"]},
{"state": DagRunState.FAILED},
{"state": "success", "note": "test_note"},
200,
),
],
)
def test_patch_dag_run_with_update_mask(
self, test_client, query_params, patch_body, expected_status_code
self, test_client, query_params, patch_body, response_body, expected_status_code
):
response = test_client.patch(
f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", params=query_params, json=patch_body
)
response_json = response.json()
assert response.status_code == expected_status_code
for key, value in response_body.items():
assert response_json.get(key) == value

def test_patch_dag_run_not_found(self, test_client):
response = test_client.patch(
Expand Down

0 comments on commit f57db71

Please sign in to comment.