From f57db717a31b2735699504a6f087bf94ef82fc66 Mon Sep 17 00:00:00 2001 From: Kalyan R Date: Fri, 8 Nov 2024 18:38:57 +0530 Subject: [PATCH] AIP-84 Add ability to update dag run note in PATCH dag_run endpoint (#43508) * include dag_run_note in update_mask * add dag run note * state can be none * add test * Fix tests * handle edge cases * add tests * remove joinedload * fix update_mask checks * fix tests * fix * remove async * undo async * fix --------- Co-authored-by: pierrejeambrun --- .../core_api/openapi/v1-generated.yaml | 16 ++-- .../core_api/routes/public/dag_run.py | 34 ++++--- .../core_api/serializers/dag_run.py | 3 +- airflow/ui/openapi-gen/queries/common.ts | 4 +- airflow/ui/openapi-gen/queries/queries.ts | 8 +- .../ui/openapi-gen/requests/schemas.gen.ts | 22 ++++- .../ui/openapi-gen/requests/services.gen.ts | 12 +-- airflow/ui/openapi-gen/requests/types.gen.ts | 9 +- .../core_api/routes/public/test_dag_run.py | 88 +++++++++++++++---- 9 files changed, 142 insertions(+), 54 deletions(-) diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 0e9221444af2f..9b52f3bc00347 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -1318,9 +1318,9 @@ paths: patch: tags: - DagRun - summary: Patch Dag Run State + summary: Patch Dag Run description: Modify a DAG Run. - operationId: patch_dag_run_state + operationId: patch_dag_run parameters: - name: dag_id in: path @@ -3694,10 +3694,16 @@ components: DAGRunPatchBody: properties: state: - $ref: '#/components/schemas/DAGRunPatchStates' + anyOf: + - $ref: '#/components/schemas/DAGRunPatchStates' + - type: 'null' + note: + anyOf: + - type: string + maxLength: 1000 + - type: 'null' + title: Note type: object - required: - - state title: DAGRunPatchBody description: DAG Run Serializer for PATCH requests. DAGRunPatchStates: diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index b05ed2ba11385..7778d7778fa17 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -99,7 +99,7 @@ def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Session, Dep ] ), ) -def patch_dag_run_state( +def patch_dag_run( dag_id: str, dag_run_id: str, patch_body: DAGRunPatchBody, @@ -121,23 +121,29 @@ def patch_dag_run_state( raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} was not found") if update_mask: - if update_mask != ["state"]: - raise HTTPException( - status.HTTP_400_BAD_REQUEST, "Only `state` field can be updated through the REST API" - ) + data = patch_body.model_dump(include=set(update_mask)) else: - update_mask = ["state"] + data = patch_body.model_dump() - for attr_name in update_mask: + for attr_name, attr_value in data.items(): if attr_name == "state": - state = getattr(patch_body, attr_name) - if state == DAGRunPatchStates.SUCCESS: - set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True) - elif state == DAGRunPatchStates.QUEUED: - set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True) + attr_value = getattr(patch_body, "state") + if attr_value == DAGRunPatchStates.SUCCESS: + set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True, session=session) + elif attr_value == DAGRunPatchStates.QUEUED: + set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, commit=True, session=session) + elif attr_value == DAGRunPatchStates.FAILED: + set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True, session=session) + elif attr_name == "note": + # Once Authentication is implemented in this FastAPI app, + # user id will be added when updating dag run note + # Refer to https://github.com/apache/airflow/issues/43534 + dag_run = session.get(DagRun, dag_run.id) + if dag_run.dag_run_note is None: + dag_run.note = (attr_value, None) else: - set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, commit=True) + dag_run.dag_run_note.content = attr_value - session.refresh(dag_run) + dag_run = session.get(DagRun, dag_run.id) return DAGRunResponse.model_validate(dag_run, from_attributes=True) diff --git a/airflow/api_fastapi/core_api/serializers/dag_run.py b/airflow/api_fastapi/core_api/serializers/dag_run.py index 15576905611c3..759c4399fbd70 100644 --- a/airflow/api_fastapi/core_api/serializers/dag_run.py +++ b/airflow/api_fastapi/core_api/serializers/dag_run.py @@ -37,7 +37,8 @@ class DAGRunPatchStates(str, Enum): class DAGRunPatchBody(BaseModel): """DAG Run Serializer for PATCH requests.""" - state: DAGRunPatchStates + state: DAGRunPatchStates | None = None + note: str | None = Field(None, max_length=1000) class DAGRunResponse(BaseModel): diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts index 953bb16291fc7..cfbb945b4e57f 100644 --- a/airflow/ui/openapi-gen/queries/common.ts +++ b/airflow/ui/openapi-gen/queries/common.ts @@ -935,8 +935,8 @@ export type DagServicePatchDagMutationResult = Awaited< export type ConnectionServicePatchConnectionMutationResult = Awaited< ReturnType >; -export type DagRunServicePatchDagRunStateMutationResult = Awaited< - ReturnType +export type DagRunServicePatchDagRunMutationResult = Awaited< + ReturnType >; export type PoolServicePatchPoolMutationResult = Awaited< ReturnType diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts index c6f4bd09dd697..3a8d508a8c426 100644 --- a/airflow/ui/openapi-gen/queries/queries.ts +++ b/airflow/ui/openapi-gen/queries/queries.ts @@ -1912,7 +1912,7 @@ export const useConnectionServicePatchConnection = < ...options, }); /** - * Patch Dag Run State + * Patch Dag Run * Modify a DAG Run. * @param data The data for the request. * @param data.dagId @@ -1922,8 +1922,8 @@ export const useConnectionServicePatchConnection = < * @returns DAGRunResponse Successful Response * @throws ApiError */ -export const useDagRunServicePatchDagRunState = < - TData = Common.DagRunServicePatchDagRunStateMutationResult, +export const useDagRunServicePatchDagRun = < + TData = Common.DagRunServicePatchDagRunMutationResult, TError = unknown, TContext = unknown, >( @@ -1954,7 +1954,7 @@ export const useDagRunServicePatchDagRunState = < TContext >({ mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) => - DagRunService.patchDagRunState({ + DagRunService.patchDagRun({ dagId, dagRunId, requestBody, diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts index 44bd279a1625a..b8c43b7ac2072 100644 --- a/airflow/ui/openapi-gen/requests/schemas.gen.ts +++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts @@ -981,11 +981,29 @@ export const $DAGResponse = { export const $DAGRunPatchBody = { properties: { state: { - $ref: "#/components/schemas/DAGRunPatchStates", + anyOf: [ + { + $ref: "#/components/schemas/DAGRunPatchStates", + }, + { + type: "null", + }, + ], + }, + note: { + anyOf: [ + { + type: "string", + maxLength: 1000, + }, + { + type: "null", + }, + ], + title: "Note", }, }, type: "object", - required: ["state"], title: "DAGRunPatchBody", description: "DAG Run Serializer for PATCH requests.", } as const; diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts index 45be069d310bc..8a6cd3e4f702a 100644 --- a/airflow/ui/openapi-gen/requests/services.gen.ts +++ b/airflow/ui/openapi-gen/requests/services.gen.ts @@ -49,8 +49,8 @@ import type { GetDagRunResponse, DeleteDagRunData, DeleteDagRunResponse, - PatchDagRunStateData, - PatchDagRunStateResponse, + PatchDagRunData, + PatchDagRunResponse, GetDagSourceData, GetDagSourceResponse, GetEventLogData, @@ -794,7 +794,7 @@ export class DagRunService { } /** - * Patch Dag Run State + * Patch Dag Run * Modify a DAG Run. * @param data The data for the request. * @param data.dagId @@ -804,9 +804,9 @@ export class DagRunService { * @returns DAGRunResponse Successful Response * @throws ApiError */ - public static patchDagRunState( - data: PatchDagRunStateData, - ): CancelablePromise { + public static patchDagRun( + data: PatchDagRunData, + ): CancelablePromise { return __request(OpenAPI, { method: "PATCH", url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}", diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts index 8dc0a3188ca3c..08f174e9b39fc 100644 --- a/airflow/ui/openapi-gen/requests/types.gen.ts +++ b/airflow/ui/openapi-gen/requests/types.gen.ts @@ -184,7 +184,8 @@ export type DAGResponse = { * DAG Run Serializer for PATCH requests. */ export type DAGRunPatchBody = { - state: DAGRunPatchStates; + state?: DAGRunPatchStates | null; + note?: string | null; }; /** @@ -932,14 +933,14 @@ export type DeleteDagRunData = { export type DeleteDagRunResponse = void; -export type PatchDagRunStateData = { +export type PatchDagRunData = { dagId: string; dagRunId: string; requestBody: DAGRunPatchBody; updateMask?: Array | null; }; -export type PatchDagRunStateResponse = DAGRunResponse; +export type PatchDagRunResponse = DAGRunResponse; export type GetDagSourceData = { accept?: string; @@ -1775,7 +1776,7 @@ export type $OpenApiTs = { }; }; patch: { - req: PatchDagRunStateData; + req: PatchDagRunData; res: { /** * Successful Response diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index 6c48cece798d2..64c3512e88b77 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -50,7 +50,7 @@ DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc) EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc) -DAG1_NOTE = "test_note" +DAG1_RUN1_NOTE = "test_note" @pytest.fixture(autouse=True) @@ -66,13 +66,13 @@ def setup(dag_maker, session=None): start_date=START_DATE, ): EmptyOperator(task_id="task_1") - dag1 = dag_maker.create_dagrun( + dag_run1 = dag_maker.create_dagrun( run_id=DAG1_RUN1_ID, state=DAG1_RUN1_STATE, run_type=DAG1_RUN1_RUN_TYPE, triggered_by=DAG1_RUN1_TRIGGERED_BY, ) - dag1.note = (DAG1_NOTE, 1) + dag_run1.note = (DAG1_RUN1_NOTE, 1) dag_maker.create_dagrun( run_id=DAG1_RUN2_ID, @@ -114,7 +114,14 @@ class TestGetDagRun: @pytest.mark.parametrize( "dag_id, run_id, state, run_type, triggered_by, dag_run_note", [ - (DAG1_ID, DAG1_RUN1_ID, DAG1_RUN1_STATE, DAG1_RUN1_RUN_TYPE, DAG1_RUN1_TRIGGERED_BY, DAG1_NOTE), + ( + DAG1_ID, + DAG1_RUN1_ID, + DAG1_RUN1_STATE, + DAG1_RUN1_RUN_TYPE, + DAG1_RUN1_TRIGGERED_BY, + DAG1_RUN1_NOTE, + ), (DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, DAG1_RUN2_TRIGGERED_BY, None), (DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, DAG2_RUN1_TRIGGERED_BY, None), (DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, DAG2_RUN2_TRIGGERED_BY, None), @@ -140,36 +147,85 @@ def test_get_dag_run_not_found(self, test_client): class TestPatchDagRun: @pytest.mark.parametrize( - "dag_id, run_id, state, response_state", + "dag_id, run_id, patch_body, response_body", [ - (DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED), - (DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS), - (DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED), + ( + DAG1_ID, + DAG1_RUN1_ID, + {"state": DagRunState.FAILED, "note": "new_note2"}, + {"state": DagRunState.FAILED, "note": "new_note2"}, + ), + ( + DAG1_ID, + DAG1_RUN2_ID, + {"state": DagRunState.SUCCESS}, + {"state": DagRunState.SUCCESS, "note": None}, + ), + ( + DAG2_ID, + DAG2_RUN1_ID, + {"state": DagRunState.QUEUED}, + {"state": DagRunState.QUEUED, "note": None}, + ), + ( + DAG1_ID, + DAG1_RUN1_ID, + {"note": "updated note"}, + {"state": DagRunState.SUCCESS, "note": "updated note"}, + ), + ( + DAG1_ID, + DAG1_RUN2_ID, + {"note": "new note", "state": DagRunState.FAILED}, + {"state": DagRunState.FAILED, "note": "new note"}, + ), + (DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": DagRunState.FAILED, "note": None}), ], ) - def test_patch_dag_run(self, test_client, dag_id, run_id, state, response_state): - response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": state}) + def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_body): + response = test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json=patch_body) assert response.status_code == 200 body = response.json() assert body["dag_id"] == dag_id assert body["run_id"] == run_id - assert body["state"] == response_state + assert body.get("state") == response_body.get("state") + assert body.get("note") == response_body.get("note") @pytest.mark.parametrize( - "query_params, patch_body, expected_status_code", + "query_params, patch_body, response_body, expected_status_code", [ - ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200), - ({}, {"state": DagRunState.SUCCESS}, 200), - ({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400), + ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, {"state": "success"}, 200), + ( + {"update_mask": ["note"]}, + {"state": DagRunState.FAILED, "note": "new_note1"}, + {"note": "new_note1", "state": "success"}, + 200, + ), + ( + {}, + {"state": DagRunState.FAILED, "note": "new_note2"}, + {"note": "new_note2", "state": "failed"}, + 200, + ), + ({"update_mask": ["note"]}, {}, {"state": "success", "note": None}, 200), + ( + {"update_mask": ["random"]}, + {"state": DagRunState.FAILED}, + {"state": "success", "note": "test_note"}, + 200, + ), ], ) def test_patch_dag_run_with_update_mask( - self, test_client, query_params, patch_body, expected_status_code + self, test_client, query_params, patch_body, response_body, expected_status_code ): response = test_client.patch( f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", params=query_params, json=patch_body ) + response_json = response.json() assert response.status_code == expected_status_code + for key, value in response_body.items(): + assert response_json.get(key) == value def test_patch_dag_run_not_found(self, test_client): response = test_client.patch(