From ff7e70056e5b4e6e61d20215eec00054415d4a77 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Fri, 13 Dec 2024 23:03:42 +0530 Subject: [PATCH] AIP-72: Extending SET RTIF endpoint to accept all JSONable types (#44843) An endpoint to set RTIF was added in #44359. This allowed only `dict[str, str]` entries to be passed down to the API which led to issues when running tests with DAGs like: ```py from __future__ import annotations import sys import time from datetime import datetime from airflow import DAG from airflow.decorators import dag, task from airflow.operators.bash import BashOperator @dag( # every minute on the 30-second mark catchup=False, tags=[], schedule=None, start_date=datetime(2021, 1, 1), ) def hello_dag(): """ ### TaskFlow API Tutorial Documentation This is a simple data pipeline example which demonstrates the use of the TaskFlow API using three simple tasks for Extract, Transform, and Load. Documentation that goes along with the Airflow TaskFlow API tutorial is located [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html) """ @task() def hello(): print("hello") time.sleep(3) print("goodbye") print("err mesg", file=sys.stderr) hello() hello_dag() ``` The reason for this is that the arguments such as `op_args` and `op_kwargs` for PythonOperator can be non-str. This leads to the conclusion that we should accept `str` keys but `JsonAble` values. Some points to note for reviewers: 1. Type we store in the table: https://github.com/apache/airflow/blob/1eb683be3a79c80927e9af1e89dabb5e78ce3136/airflow/models/renderedtifields.py#L76. Hence we should be able to accept any JsonAble type and store it; values that are not stored as-is, such as tuples and sets, should be converted to strings first and then stored. ### What does this PR change? - Get rid of the `RTIFPayload` and consume the payload directly in the API handler. 
- Handling special case of `tuples` - they are json serialisable but we used to store them as lists when passed as tuples, because of usage of json.dumps(). It has been made like this now: ``` def is_jsonable(x): try: json.dumps(x) if isinstance(x, tuple): # Tuple is converted to list in json.dumps # so while it is jsonable, it changes the type which might be a surprise # for the user, so instead we return False here -- which will convert it to string return False ``` - Reusing `serialize_template_field` from `airflow.serialization.helpers` because copy pasting code will be expensive, hard to maintain. We will revisit it anyways when we port the logic of templating to TASK SDK. Discussion: https://github.com/apache/airflow/pull/44843/files#r1882834039 - Added test cases with different scopes and different types to handle different cases of templated_fields well. --- .../execution_api/datamodels/taskinstance.py | 6 +- .../execution_api/routes/task_instances.py | 6 +- airflow/serialization/helpers.py | 5 ++ .../src/airflow/sdk/execution_time/comms.py | 2 +- .../airflow/sdk/execution_time/task_runner.py | 24 ++++---- .../tests/execution_time/test_task_runner.py | 58 +++++++++++++++++++ .../endpoints/test_task_instance_endpoint.py | 6 +- .../routes/public/test_task_instances.py | 6 +- .../routes/test_task_instances.py | 48 ++++++--------- tests/models/test_renderedtifields.py | 36 ++++++++++-- 10 files changed, 137 insertions(+), 60 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index e0d8f371f09d6..bbc557d012463 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -21,7 +21,7 @@ from datetime import timedelta from typing import Annotated, Any, Literal, Union -from pydantic import Discriminator, Field, RootModel, Tag, WithJsonSchema +from pydantic import Discriminator, Field, 
Tag, WithJsonSchema from airflow.api_fastapi.common.types import UtcDateTime from airflow.api_fastapi.core_api.base import BaseModel @@ -135,7 +135,3 @@ class TaskInstance(BaseModel): run_id: str try_number: int map_index: int | None = None - - -"""Schema for setting RTIF for a task instance.""" -RTIFPayload = RootModel[dict[str, str]] diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 90bbe1c1d3e5b..e06798209c5da 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -22,6 +22,7 @@ from uuid import UUID from fastapi import Body, HTTPException, status +from pydantic import JsonValue from sqlalchemy import update from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.sql import select @@ -29,7 +30,6 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( - RTIFPayload, TIDeferredStatePayload, TIEnterRunningPayload, TIHeartbeatInfo, @@ -237,7 +237,7 @@ def ti_heartbeat( ) def ti_put_rtif( task_instance_id: UUID, - put_rtif_payload: RTIFPayload, + put_rtif_payload: Annotated[dict[str, JsonValue], Body()], session: SessionDep, ): """Add an RTIF entry for a task instance, sent by the worker.""" @@ -247,6 +247,6 @@ def ti_put_rtif( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, ) - _update_rtif(task_instance, put_rtif_payload.model_dump(), session) + _update_rtif(task_instance, put_rtif_payload, session) return {"message": "Rendered task instance fields successfully set"} diff --git a/airflow/serialization/helpers.py b/airflow/serialization/helpers.py index 85bf3a1cc551c..dc1aabbca986c 100644 --- a/airflow/serialization/helpers.py +++ b/airflow/serialization/helpers.py @@ -36,6 +36,11 @@ def serialize_template_field(template_field: 
Any, name: str) -> str | dict | lis def is_jsonable(x): try: json.dumps(x) + if isinstance(x, tuple): + # Tuple is converted to list in json.dumps + # so while it is jsonable, it changes the type which might be a surprise + # for the user, so instead we return False here -- which will convert it to string + return False except (TypeError, OverflowError): return False else: diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 34d6a9e3156d4..9e6093a092da0 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -176,7 +176,7 @@ class SetRenderedFields(BaseModel): # We are using a BaseModel here compared to server using RootModel because we # have a discriminator running with "type", and RootModel doesn't support type - rendered_fields: dict[str, str | None] + rendered_fields: dict[str, JsonValue] type: Literal["SetRenderedFields"] = "SetRenderedFields" diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index c01677ce1a798..5aca25f590e5e 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -27,7 +27,7 @@ import attrs import structlog -from pydantic import BaseModel, ConfigDict, TypeAdapter +from pydantic import BaseModel, ConfigDict, JsonValue, TypeAdapter from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState from airflow.sdk.definitions.baseoperator import BaseOperator @@ -196,22 +196,26 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]: # 1. Implementing the part where we pull in the logic to render fields and add that here # for all operators, we should do setattr(task, templated_field, rendered_templated_field) # task.templated_fields should give all the templated_fields and each of those fields should - # give the rendered values. 
+ # give the rendered values. task.templated_fields should already be in a JSONable format and + # we should not have to handle that here. # 2. Once rendered, we call the `set_rtif` API to store the rtif in the metadata DB - templated_fields = ti.task.template_fields - payload = {} - - for field in templated_fields: - if field not in payload: - payload[field] = getattr(ti.task, field) # so that we do not call the API unnecessarily - if payload: - SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=payload)) + if rendered_fields := _get_rendered_fields(ti.task): + SUPERVISOR_COMMS.send_request(log=log, msg=SetRenderedFields(rendered_fields=rendered_fields)) return ti, log +def _get_rendered_fields(task: BaseOperator) -> dict[str, JsonValue]: + # TODO: Port one of the following to Task SDK + # airflow.serialization.helpers.serialize_template_field or + # airflow.models.renderedtifields.get_serialized_template_fields + from airflow.serialization.helpers import serialize_template_field + + return {field: serialize_template_field(getattr(task, field), field) for field in task.template_fields} + + def run(ti: RuntimeTaskInstance, log: Logger): """Run the task in this process.""" from airflow.exceptions import ( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 517157e0a7a90..c9755c252bbe6 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -260,3 +260,61 @@ def test_startup_basic_templated_dag(mocked_parse): ), log=mock.ANY, ) + + +@pytest.mark.parametrize( + ["task_params", "expected_rendered_fields"], + [ + pytest.param( + {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + id="no_templates", + ), + pytest.param( + { + "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}], + "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": 
{"nested_key": "nested_value"}}, + }, + { + "op_args": ["arg1", "arg2", 1, 2, 3.75, {"key": "value"}], + "op_kwargs": {"key1": "value1", "key2": 99.0, "key3": {"nested_key": "nested_value"}}, + }, + id="mixed_types", + ), + pytest.param( + {"my_tup": (1, 2), "my_set": {1, 2, 3}}, + {"my_tup": "(1, 2)", "my_set": "{1, 2, 3}"}, + id="tuples_and_sets", + ), + ], +) +def test_startup_dag_with_templated_fields(mocked_parse, task_params, expected_rendered_fields): + """Test startup of a DAG with various templated fields.""" + + class CustomOperator(BaseOperator): + template_fields = tuple(task_params.keys()) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + for key, value in task_params.items(): + setattr(self, key, value) + + task = CustomOperator(task_id="templated_task") + + what = StartupDetails( + ti=TaskInstance(id=uuid7(), task_id="templated_task", dag_id="basic_dag", run_id="c", try_number=1), + file="", + requests_fd=0, + ) + mocked_parse(what, "basic_dag", task) + + with mock.patch( + "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True + ) as mock_supervisor_comms: + mock_supervisor_comms.get_message.return_value = what + + startup() + mock_supervisor_comms.send_request.assert_called_once_with( + msg=SetRenderedFields(rendered_fields=expected_rendered_fields), + log=mock.ANY, + ) diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index f39cff8ae7650..a4089c9785a98 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -351,7 +351,7 @@ def test_should_respond_200_task_instance_with_rendered(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, 
"rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -403,7 +403,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -2371,7 +2371,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index 9b427253b2965..7ce944a4d47ae 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -344,7 +344,7 @@ def test_should_respond_200_task_instance_with_rendered(self, test_client, sessi "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ -444,7 +444,7 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se "try_number": 0, "unixname": getuser(), "dag_run_id": "TEST_DAG_RUN_ID", - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, @@ 
-3070,7 +3070,7 @@ def test_set_note_should_respond_200_mapped_task_instance_with_rtif(self, test_c "try_number": 0, "unixname": getuser(), "dag_run_id": self.RUN_ID, - "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None}, + "rendered_fields": {"op_args": "()", "op_kwargs": {}, "templates_dict": None}, "rendered_map_index": None, "trigger": None, "triggerer_job": None, diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index c13effee0bb16..15e56bbc58710 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -422,16 +422,31 @@ def teardown_method(self): clear_db_runs() clear_rendered_ti_fields() - def test_ti_put_rtif_success(self, client, session, create_task_instance): + @pytest.mark.parametrize( + "payload", + [ + # string value + {"field1": "string_value", "field2": "another_string"}, + # dictionary value + {"field1": {"nested_key": "nested_value"}}, + # string lists value + {"field1": ["123"], "field2": ["a", "b", "c"]}, + # list of JSON values + {"field1": [1, "string", 3.14, True, None, {"nested": "dict"}]}, + # nested dictionary with mixed types in lists + { + "field1": {"nested_dict": {"key1": 123, "key2": "value"}}, + "field2": [3.14, {"sub_key": "sub_value"}, [1, 2]], + }, + ], + ) + def test_ti_put_rtif_success(self, client, session, create_task_instance, payload): ti = create_task_instance( task_id="test_ti_put_rtif_success", state=State.RUNNING, session=session, ) session.commit() - - payload = {"field1": "rendered_value1", "field2": "rendered_value2"} - response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) assert response.status_code == 201 assert response.json() == {"message": "Rendered task instance fields successfully set"} @@ -461,28 +476,3 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): 
response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload) assert response.status_code == 404 assert response.json()["detail"] == "Not Found" - - def test_ti_put_rtif_extra_fields(self, client, session, create_task_instance): - ti = create_task_instance( - task_id="test_ti_put_rtif_missing_ti", - state=State.RUNNING, - session=session, - ) - session.commit() - - payload = { - "field1": "rendered_value1", - "field2": "rendered_value2", - "invalid_key": {"field3": "rendered_value3"}, - } - - response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) - assert response.status_code == 422 - assert response.json()["detail"] == [ - { - "input": {"field3": "rendered_value3"}, - "loc": ["body", "invalid_key"], - "msg": "Input should be a valid string", - "type": "string_type", - } - ] diff --git a/tests/models/test_renderedtifields.py b/tests/models/test_renderedtifields.py index 3f1b13cd1a35d..ded755c4d01d9 100644 --- a/tests/models/test_renderedtifields.py +++ b/tests/models/test_renderedtifields.py @@ -19,6 +19,7 @@ from __future__ import annotations +import ast import os from collections import Counter from datetime import date, timedelta @@ -100,8 +101,12 @@ def teardown_method(self): (None, None), ([], []), ({}, {}), + ((), "()"), + (set(), "set()"), ("test-string", "test-string"), ({"foo": "bar"}, {"foo": "bar"}), + (("foo", "bar"), "('foo', 'bar')"), + ({"foo", "bar"}, "{'foo', 'bar'}"), ("{{ task.task_id }}", "test"), (date(2018, 12, 6), "2018-12-06"), (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"), @@ -158,16 +163,35 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field, da assert ti.dag_id == rtif.dag_id assert ti.task_id == rtif.task_id assert ti.run_id == rtif.run_id - assert expected_rendered_field == rtif.rendered_fields.get("bash_command") + if type(templated_field) is set: + # the output order of a set is non-deterministic and can change per process. 
+ # this validation can fail if that happens before stringification, so we convert to set and compare. + assert ast.literal_eval(expected_rendered_field) == ast.literal_eval( + rtif.rendered_fields.get("bash_command") + ) + else: + assert expected_rendered_field == rtif.rendered_fields.get("bash_command") session.add(rtif) session.flush() - assert RTIF.get_templated_fields(ti=ti, session=session) == { - "bash_command": expected_rendered_field, - "env": None, - "cwd": None, - } + if type(templated_field) is set: + # the output order of a set is non-deterministic and can change per process. + # this validation can fail if that happens before stringification, so we convert to set and compare. + expected = RTIF.get_templated_fields(ti=ti, session=session) + expected["bash_command"] = ast.literal_eval(expected["bash_command"]) + actual = { + "bash_command": ast.literal_eval(expected_rendered_field), + "env": None, + "cwd": None, + } + assert expected == actual + else: + assert RTIF.get_templated_fields(ti=ti, session=session) == { + "bash_command": expected_rendered_field, + "env": None, + "cwd": None, + } # Test the else part of get_templated_fields # i.e. for the TIs that are not stored in RTIF table # Fetching them will return None