From 6e9f81380c915706928e9448c046b8aa59c2db97 Mon Sep 17 00:00:00 2001 From: Amogh Desai Date: Thu, 30 Jan 2025 12:04:39 +0530 Subject: [PATCH] AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK (#46020) --- .../execution_api/datamodels/taskinstance.py | 7 +++ .../execution_api/routes/task_instances.py | 42 +++++++++++++ airflow/models/taskinstance.py | 21 ++++--- task_sdk/src/airflow/sdk/api/client.py | 16 ++++- .../airflow/sdk/api/datamodels/_generated.py | 9 +++ .../airflow/sdk/definitions/asset/__init__.py | 9 +++ .../src/airflow/sdk/execution_time/comms.py | 12 ++++ .../airflow/sdk/execution_time/supervisor.py | 4 ++ .../airflow/sdk/execution_time/task_runner.py | 52 ++++++++++------ .../tests/execution_time/test_supervisor.py | 23 ++++++- .../tests/execution_time/test_task_runner.py | 39 ++++++++++++ .../routes/test_task_instances.py | 61 +++++++++++++++++++ 12 files changed, 266 insertions(+), 29 deletions(-) diff --git a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index 6cc82259cf758b..e427cac5f3db48 100644 --- a/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -249,3 +249,10 @@ class PrevSuccessfulDagRunResponse(BaseModel): data_interval_end: UtcDateTime | None = None start_date: UtcDateTime | None = None end_date: UtcDateTime | None = None + + +class TIRuntimeCheckPayload(BaseModel): + """Payload for performing Runtime checks on the TaskInstance model as requested by the SDK.""" + + inlets: list[AssetProfile] | None = None + outlets: list[AssetProfile] | None = None diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index 899017e612d5a2..155f96f861ab5a 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -37,10 +37,12 @@ TIHeartbeatInfo, TIRescheduleStatePayload, TIRunContext, + TIRuntimeCheckPayload, TIStateUpdate, TISuccessStatePayload, TITerminalStatePayload, ) +from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException from airflow.models.dagrun import DagRun as DR from airflow.models.taskinstance import TaskInstance as TI, _update_rtif from airflow.models.taskreschedule import TaskReschedule @@ -442,6 +444,46 @@ def get_previous_successful_dagrun( return PrevSuccessfulDagRunResponse.model_validate(dag_run) +@router.post( + "/{task_instance_id}/runtime-checks", + status_code=status.HTTP_204_NO_CONTENT, + # TODO: Add description to the operation + # TODO: Add Operation ID to control the function name in the OpenAPI spec + # TODO: Do we need to use create_openapi_http_exception_doc here? + responses={ + status.HTTP_400_BAD_REQUEST: {"description": "Task Instance failed the runtime checks."}, + status.HTTP_409_CONFLICT: { + "description": "Task Instance isn't in a running state. Cannot perform runtime checks." + }, + status.HTTP_422_UNPROCESSABLE_ENTITY: { + "description": "Invalid payload for requested runtime checks on the Task Instance." + }, + }, +) +def ti_runtime_checks( + task_instance_id: UUID, + payload: TIRuntimeCheckPayload, + session: SessionDep, +): + ti_id_str = str(task_instance_id) + task_instance = session.scalar(select(TI).where(TI.id == ti_id_str)) + if task_instance.state != State.RUNNING: + raise HTTPException(status_code=status.HTTP_409_CONFLICT) + + try: + TI.validate_inlet_outlet_assets_activeness(payload.inlets, payload.outlets, session) # type: ignore + except AirflowInactiveAssetInInletOrOutletException as e: + log.error("Task Instance %s fails the runtime checks.", ti_id_str) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "reason": "validation_failed", + "message": "Task Instance fails the runtime checks", + "error": str(e), + }, + ) + + def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: """Is task instance is eligible for retry.""" if state == State.RESTARTING: diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index f1aa3a8236e9c5..69b6d147eadb3e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -260,7 +260,10 @@ def _run_raw_task( context = ti.get_template_context(ignore_param_exceptions=False, session=session) try: - ti._validate_inlet_outlet_assets_activeness(session=session) + if ti.task: + inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] + outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] + TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session) if not mark_success: TaskInstance._execute_task_with_callbacks( self=ti, # type: ignore[arg-type] @@ -3715,16 +3718,20 @@ def duration_expression_update( } ) - def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None: - if not self.task or not (self.task.outlets or self.task.inlets): + @staticmethod + def validate_inlet_outlet_assets_activeness( + inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session + ) -> None: + if not (inlets or outlets): return all_asset_unique_keys = { - AssetUniqueKey.from_asset(inlet_or_outlet) - for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets) - if isinstance(inlet_or_outlet, Asset) + AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore + for inlet_or_outlet in itertools.chain(inlets, outlets) } - inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session) + inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys( + all_asset_unique_keys, session + ) if inactive_asset_unique_keys: raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys) diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 443256e3a67eea..821e589ad522f5 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -32,6 +32,7 @@ from tenacity import before_log, wait_random_exponential from uuid6 import uuid7 +from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload from airflow.sdk import __version__ from airflow.sdk.api.datamodels._generated import ( AssetResponse, @@ -52,7 +53,7 @@ XComResponse, ) from airflow.sdk.exceptions import ErrorType -from airflow.sdk.execution_time.comms import ErrorResponse +from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, RuntimeCheckOnTask from airflow.utils.net import get_hostname from airflow.utils.platform import getuser @@ -177,6 +178,19 @@ def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunR resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun") return PrevSuccessfulDagRunResponse.model_validate_json(resp.read()) + def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) -> OKResponse: + body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True)) + try: + self.client.post(f"task-instances/{id}/runtime-checks", content=body.model_dump_json()) + return OKResponse(ok=True) + except ServerResponseError as e: + if e.response.status_code == 400: + return OKResponse(ok=False) + elif e.response.status_code == 409: + # The TI isn't in the right state to perform the check, but we shouldn't fail the task for that + return OKResponse(ok=True) + raise + class ConnectionOperations: __slots__ = ("client",) diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py index 3383e61d1c395a..1d6d0eb4156c3b 100644 --- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -148,6 +148,15 @@ class TIRescheduleStatePayload(BaseModel): end_date: Annotated[datetime, Field(title="End Date")] +class TIRuntimeCheckPayload(BaseModel): + """ + Payload for performing Runtime checks on the TaskInstance model as requested by the SDK. + """ + + inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None + outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] = None + + class TISuccessStatePayload(BaseModel): """ Schema for updating TaskInstance to success state. diff --git a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py index 91ebc4ab6cb8ba..b976bb8c1563f2 100644 --- a/task_sdk/src/airflow/sdk/definitions/asset/__init__.py +++ b/task_sdk/src/airflow/sdk/definitions/asset/__init__.py @@ -27,6 +27,7 @@ import attrs +from airflow.sdk.api.datamodels._generated import AssetProfile from airflow.serialization.dag_dependency import DagDependency if TYPE_CHECKING: @@ -428,6 +429,14 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe dependency_id=self.name, ) + def asprofile(self) -> AssetProfile: + """ + Profiles Asset to AssetProfile. + + :meta private: + """ + return AssetProfile(name=self.name or None, uri=self.uri or None, asset_type=Asset.__name__) + class AssetRef(BaseAsset, AttrsInstance): """ diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index 3ab8addc8bbf8c..93ea133f4895fb 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -60,6 +60,7 @@ TIDeferredStatePayload, TIRescheduleStatePayload, TIRunContext, + TIRuntimeCheckPayload, TISuccessStatePayload, VariableResponse, XComResponse, @@ -169,6 +170,11 @@ class ErrorResponse(BaseModel): type: Literal["ErrorResponse"] = "ErrorResponse" +class OKResponse(BaseModel): + ok: bool + type: Literal["OKResponse"] = "OKResponse" + + ToTask = Annotated[ Union[ AssetResult, @@ -178,6 +184,7 @@ class ErrorResponse(BaseModel): StartupDetails, VariableResult, XComResult, + OKResponse, ], Field(discriminator="type"), ] @@ -220,6 +227,10 @@ class RescheduleTask(TIRescheduleStatePayload): type: Literal["RescheduleTask"] = "RescheduleTask" +class RuntimeCheckOnTask(TIRuntimeCheckPayload): + type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask" + + class GetXCom(BaseModel): key: str dag_id: str @@ -317,6 +328,7 @@ class GetPrevSuccessfulDagRun(BaseModel): SetRenderedFields, SetXCom, TaskState, + RuntimeCheckOnTask, ], Field(discriminator="type"), ] diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 569855016cfe6b..30050c0b955191 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -73,6 +73,7 @@ PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, + RuntimeCheckOnTask, SetRenderedFields, SetXCom, StartupDetails, @@ -767,6 +768,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): if isinstance(msg, TaskState): self._terminal_state = msg.state self._task_end_time_monotonic = time.monotonic() + elif isinstance(msg, RuntimeCheckOnTask): + runtime_check_resp = self.client.task_instances.runtime_checks(id=self.id, msg=msg) + resp = runtime_check_resp.model_dump_json().encode() elif isinstance(msg, SucceedTask): self._terminal_state = msg.state self.client.task_instances.succeed( diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index c2d2c51b630be2..715dbf75dd77de 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -40,7 +40,9 @@ from airflow.sdk.execution_time.comms import ( DeferTask, GetXCom, + OKResponse, RescheduleTask, + RuntimeCheckOnTask, SetRenderedFields, SetXCom, StartupDetails, @@ -501,26 +503,36 @@ def run(ti: RuntimeTaskInstance, log: Logger): # TODO: Get a real context object ti.hostname = get_hostname() ti.task = ti.task.prepare_for_execution() - context = ti.get_template_context() - with set_current_context(context): - jinja_env = ti.task.dag.get_template_env() - ti.task = ti.render_templates(context=context, jinja_env=jinja_env) - result = _execute_task(context, ti.task) - - _push_xcom_if_needed(result, ti) - - task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) - - # TODO: Get things from _execute_task_with_callbacks - # - Clearing XCom - # - Update RTIF - # - Pre Execute - # etc - msg = SucceedTask( - end_date=datetime.now(tz=timezone.utc), - task_outlets=task_outlets, - outlet_events=outlet_events, - ) + if ti.task.inlets or ti.task.outlets: + inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)] + outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)] + SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore + msg = SUPERVISOR_COMMS.get_message() # type: ignore + + if isinstance(msg, OKResponse) and not msg.ok: + log.info("Runtime checks failed for task, marking task as failed..") + msg = TaskState( + state=TerminalTIState.FAILED, + end_date=datetime.now(tz=timezone.utc), + ) + else: + context = ti.get_template_context() + with set_current_context(context): + jinja_env = ti.task.dag.get_template_env() + ti.task = ti.render_templates(context=context, jinja_env=jinja_env) + # TODO: Get things from _execute_task_with_callbacks + # - Pre Execute + # etc + result = _execute_task(context, ti.task) + + _push_xcom_if_needed(result, ti) + + task_outlets, outlet_events = _process_outlets(context, ti.task.outlets) + msg = SucceedTask( + end_date=datetime.now(tz=timezone.utc), + task_outlets=task_outlets, + outlet_events=outlet_events, + ) except TaskDeferred as defer: # TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id? log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id) diff --git a/task_sdk/tests/execution_time/test_supervisor.py b/task_sdk/tests/execution_time/test_supervisor.py index 2bd631fc0803cc..4bc8febc67c14e 100644 --- a/task_sdk/tests/execution_time/test_supervisor.py +++ b/task_sdk/tests/execution_time/test_supervisor.py @@ -39,7 +39,7 @@ from airflow.executors.workloads import BundleInfo from airflow.sdk.api import client as sdk_client from airflow.sdk.api.client import ServerResponseError -from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState +from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState from airflow.sdk.execution_time.comms import ( AssetResult, ConnectionResult, @@ -50,9 +50,11 @@ GetPrevSuccessfulDagRun, GetVariable, GetXCom, + OKResponse, PrevSuccessfulDagRunResult, PutVariable, RescheduleTask, + RuntimeCheckOnTask, SetRenderedFields, SetXCom, SucceedTask, @@ -1011,6 +1013,25 @@ def watched_subprocess(self, mocker): ), id="get_prev_successful_dagrun", ), + pytest.param( + RuntimeCheckOnTask( + inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], + outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], + ), + b'{"ok":true,"type":"OKResponse"}\n', + "task_instances.runtime_checks", + (), + { + "id": TI_ID, + "msg": RuntimeCheckOnTask( + inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore + outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore + type="RuntimeCheckOnTask", + ), + }, + OKResponse(ok=True), + id="runtime_check_on_task", + ), ], ) def test_handle_requests( diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index 250b7e765a0cf5..d9aa675242cd08 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -47,7 +47,9 @@ GetConnection, GetVariable, GetXCom, + OKResponse, PrevSuccessfulDagRunResult, + RuntimeCheckOnTask, SetRenderedFields, StartupDetails, SucceedTask, @@ -651,6 +653,43 @@ def test_run_with_asset_outlets( mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY) +def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms): + """Test running a basic tasks with inlets and outlets.""" + from airflow.providers.standard.operators.bash import BashOperator + + task = BashOperator( + outlets=[ + Asset(name="name", uri="s3://bucket/my-task"), + Asset(name="new-name", uri="s3://bucket/my-task"), + ], + inlets=[ + Asset(name="name", uri="s3://bucket/my-task"), + Asset(name="new-name", uri="s3://bucket/my-task"), + ], + task_id="inlets-and-outlets", + bash_command="echo 'hi'", + ) + + ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets") + mock_supervisor_comms.get_message.return_value = OKResponse( + ok=True, + ) + + run(ti, log=mock.MagicMock()) + + expected = RuntimeCheckOnTask( + inlets=[ + AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"), + AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"), + ], + outlets=[ + AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"), + AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"), + ], + ) + mock_supervisor_comms.send_request.assert_any_call(msg=expected, log=mock.ANY) + + class TestRuntimeTaskInstance: def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context): """Test get_template_context without ti_context_from_server.""" diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 9ccd1b0d088a66..8a7b21699f25e8 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -25,9 +25,11 @@ from sqlalchemy import select, update from sqlalchemy.exc import SQLAlchemyError +from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel from airflow.models.taskinstance import TaskInstance +from airflow.sdk.definitions.asset import AssetUniqueKey from airflow.utils import timezone from airflow.utils.state import State, TaskInstanceState, TerminalTIState @@ -655,6 +657,65 @@ def test_ti_update_state_to_failed_without_retry_table_check(self, client, sessi assert ti.next_kwargs is None assert ti.duration == 3600.00 + @pytest.mark.parametrize( + ("state", "expected_status_code"), + [ + (State.RUNNING, 204), + (State.SUCCESS, 409), + (State.QUEUED, 409), + (State.FAILED, 409), + ], + ) + def test_ti_runtime_checks_success( + self, client, session, create_task_instance, state, expected_status_code + ): + ti = create_task_instance( + task_id="test_ti_runtime_checks", + state=state, + ) + session.commit() + + with mock.patch( + "airflow.models.taskinstance.TaskInstance.validate_inlet_outlet_assets_activeness" + ) as mock_validate_inlet_outlet_assets_activeness: + mock_validate_inlet_outlet_assets_activeness.return_value = None + response = client.post( + f"/execution/task-instances/{ti.id}/runtime-checks", + json={ + "inlets": [], + "outlets": [], + }, + ) + + assert response.status_code == expected_status_code + + session.expire_all() + + def test_ti_runtime_checks_failure(self, client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_runtime_checks_failure", + state=State.RUNNING, + ) + session.commit() + + with mock.patch( + "airflow.models.taskinstance.TaskInstance.validate_inlet_outlet_assets_activeness" + ) as mock_validate_inlet_outlet_assets_activeness: + mock_validate_inlet_outlet_assets_activeness.side_effect = ( + AirflowInactiveAssetInInletOrOutletException([AssetUniqueKey(name="abc", uri="something")]) + ) + response = client.post( + f"/execution/task-instances/{ti.id}/runtime-checks", + json={ + "inlets": [], + "outlets": [], + }, + ) + + assert response.status_code == 400 + + session.expire_all() + class TestTIHealthEndpoint: def setup_method(self):