Skip to content

Commit

Permalink
AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh authored Jan 30, 2025
1 parent 0637366 commit dc4ce65
Show file tree
Hide file tree
Showing 12 changed files with 266 additions and 29 deletions.
7 changes: 7 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,10 @@ class PrevSuccessfulDagRunResponse(BaseModel):
data_interval_end: UtcDateTime | None = None
start_date: UtcDateTime | None = None
end_date: UtcDateTime | None = None


class TIRuntimeCheckPayload(BaseModel):
"""Payload for performing Runtime checks on the TaskInstance model as requested by the SDK."""

inlets: list[AssetProfile] | None = None
outlets: list[AssetProfile] | None = None
42 changes: 42 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TIStateUpdate,
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
Expand Down Expand Up @@ -442,6 +444,46 @@ def get_previous_successful_dagrun(
return PrevSuccessfulDagRunResponse.model_validate(dag_run)


@router.post(
"/{task_instance_id}/runtime-checks",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add description to the operation
# TODO: Add Operation ID to control the function name in the OpenAPI spec
# TODO: Do we need to use create_openapi_http_exception_doc here?
responses={
status.HTTP_400_BAD_REQUEST: {"description": "Task Instance failed the runtime checks."},
status.HTTP_409_CONFLICT: {
"description": "Task Instance isn't in a running state. Cannot perform runtime checks."
},
status.HTTP_422_UNPROCESSABLE_ENTITY: {
"description": "Invalid payload for requested runtime checks on the Task Instance."
},
},
)
def ti_runtime_checks(
task_instance_id: UUID,
payload: TIRuntimeCheckPayload,
session: SessionDep,
):
ti_id_str = str(task_instance_id)
task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
if task_instance.state != State.RUNNING:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)

try:
TI.validate_inlet_outlet_assets_activeness(payload.inlets, payload.outlets, session) # type: ignore
except AirflowInactiveAssetInInletOrOutletException as e:
log.error("Task Instance %s fails the runtime checks.", ti_id_str)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "validation_failed",
"message": "Task Instance fails the runtime checks",
"error": str(e),
},
)


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
Expand Down
21 changes: 14 additions & 7 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def _run_raw_task(
context = ti.get_template_context(ignore_param_exceptions=False, session=session)

try:
ti._validate_inlet_outlet_assets_activeness(session=session)
if ti.task:
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
if not mark_success:
TaskInstance._execute_task_with_callbacks(
self=ti, # type: ignore[arg-type]
Expand Down Expand Up @@ -3715,16 +3718,20 @@ def duration_expression_update(
}
)

def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None:
if not self.task or not (self.task.outlets or self.task.inlets):
@staticmethod
def validate_inlet_outlet_assets_activeness(
inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session
) -> None:
if not (inlets or outlets):
return

all_asset_unique_keys = {
AssetUniqueKey.from_asset(inlet_or_outlet)
for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets)
if isinstance(inlet_or_outlet, Asset)
AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore
for inlet_or_outlet in itertools.chain(inlets, outlets)
}
inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
all_asset_unique_keys, session
)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)

Expand Down
16 changes: 15 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
Expand All @@ -52,7 +53,7 @@
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, RuntimeCheckOnTask
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -177,6 +178,19 @@ def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunR
resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) -> OKResponse:
body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True))
try:
self.client.post(f"task-instances/{id}/runtime-checks", content=body.model_dump_json())
return OKResponse(ok=True)
except ServerResponseError as e:
if e.response.status_code == 400:
return OKResponse(ok=False)
elif e.response.status_code == 409:
# The TI isn't in the right state to perform the check, but we shouldn't fail the task for that
return OKResponse(ok=True)
raise


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ class TIRescheduleStatePayload(BaseModel):
end_date: Annotated[datetime, Field(title="End Date")]


class TIRuntimeCheckPayload(BaseModel):
"""
Payload for performing Runtime checks on the TaskInstance model as requested by the SDK.
"""

inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None
outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] = None


class TISuccessStatePayload(BaseModel):
"""
Schema for updating TaskInstance to success state.
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import attrs

from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.serialization.dag_dependency import DagDependency

if TYPE_CHECKING:
Expand Down Expand Up @@ -428,6 +429,14 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
dependency_id=self.name,
)

def asprofile(self) -> AssetProfile:
"""
Profiles Asset to AssetProfile.
:meta private:
"""
return AssetProfile(name=self.name or None, uri=self.uri or None, asset_type=Asset.__name__)


class AssetRef(BaseAsset, AttrsInstance):
"""
Expand Down
12 changes: 12 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TISuccessStatePayload,
VariableResponse,
XComResponse,
Expand Down Expand Up @@ -169,6 +170,11 @@ class ErrorResponse(BaseModel):
type: Literal["ErrorResponse"] = "ErrorResponse"


class OKResponse(BaseModel):
ok: bool
type: Literal["OKResponse"] = "OKResponse"


ToTask = Annotated[
Union[
AssetResult,
Expand All @@ -178,6 +184,7 @@ class ErrorResponse(BaseModel):
StartupDetails,
VariableResult,
XComResult,
OKResponse,
],
Field(discriminator="type"),
]
Expand Down Expand Up @@ -220,6 +227,10 @@ class RescheduleTask(TIRescheduleStatePayload):
type: Literal["RescheduleTask"] = "RescheduleTask"


class RuntimeCheckOnTask(TIRuntimeCheckPayload):
type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand Down Expand Up @@ -317,6 +328,7 @@ class GetPrevSuccessfulDagRun(BaseModel):
SetRenderedFields,
SetXCom,
TaskState,
RuntimeCheckOnTask,
],
Field(discriminator="type"),
]
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -767,6 +768,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, RuntimeCheckOnTask):
runtime_check_resp = self.client.task_instances.runtime_checks(id=self.id, msg=msg)
resp = runtime_check_resp.model_dump_json().encode()
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self.client.task_instances.succeed(
Expand Down
52 changes: 32 additions & 20 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
OKResponse,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -501,26 +503,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Get a real context object
ti.hostname = get_hostname()
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)

# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Update RTIF
# - Pre Execute
# etc
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
if ti.task.inlets or ti.task.outlets:
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore
msg = SUPERVISOR_COMMS.get_message() # type: ignore

if isinstance(msg, OKResponse) and not msg.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
else:
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
# TODO: Get things from _execute_task_with_callbacks
# - Pre Execute
# etc
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
except TaskDeferred as defer:
# TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
Expand Down
23 changes: 22 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from airflow.executors.workloads import BundleInfo
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
AssetResult,
ConnectionResult,
Expand All @@ -50,9 +50,11 @@
GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SucceedTask,
Expand Down Expand Up @@ -1011,6 +1013,25 @@ def watched_subprocess(self, mocker):
),
id="get_prev_successful_dagrun",
),
pytest.param(
RuntimeCheckOnTask(
inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
),
b'{"ok":true,"type":"OKResponse"}\n',
"task_instances.runtime_checks",
(),
{
"id": TI_ID,
"msg": RuntimeCheckOnTask(
inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore
outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore
type="RuntimeCheckOnTask",
),
},
OKResponse(ok=True),
id="runtime_check_on_task",
),
],
)
def test_handle_requests(
Expand Down
Loading

0 comments on commit dc4ce65

Please sign in to comment.