Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK #46020

Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
14c4948
AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK
amoghrajesh Jan 24, 2025
2aac70f
fixing tests
amoghrajesh Jan 26, 2025
ba8029e
adding tests for supervisor
amoghrajesh Jan 27, 2025
3781f5c
review comments ash
amoghrajesh Jan 27, 2025
6187c4c
fixing tests
amoghrajesh Jan 27, 2025
def0df8
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 27, 2025
9c7d7bc
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 28, 2025
835a04c
review part 1: changing the execution api
amoghrajesh Jan 28, 2025
b621a86
review part 2: changing comms
amoghrajesh Jan 28, 2025
5a6fe55
review comments on the task runner side
amoghrajesh Jan 28, 2025
bfd88b8
introducing asprofile on asste
amoghrajesh Jan 28, 2025
df21777
fixing tests
amoghrajesh Jan 28, 2025
9f262ff
fixing tests in CI
amoghrajesh Jan 28, 2025
e370faf
review comments ash
amoghrajesh Jan 28, 2025
a4cc173
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 28, 2025
c5c77fe
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
9ce6a13
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
3664b05
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
ab6cd46
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
618ace5
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
d1b828f
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 30, 2025
9559d6f
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,10 @@ class PrevSuccessfulDagRunResponse(BaseModel):
data_interval_end: UtcDateTime | None = None
start_date: UtcDateTime | None = None
end_date: UtcDateTime | None = None


class TIRuntimeCheckPayload(BaseModel):
"""Payload for performing Runtime checks on the TaskInstance model as requested by the SDK."""

inlets: list[AssetProfile] | None = None
outlets: list[AssetProfile] | None = None
42 changes: 42 additions & 0 deletions airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TIStateUpdate,
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
Expand Down Expand Up @@ -442,6 +444,46 @@ def get_previous_successful_dagrun(
return PrevSuccessfulDagRunResponse.model_validate(dag_run)


@router.post(
"/{task_instance_id}/runtime-checks",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add description to the operation
# TODO: Add Operation ID to control the function name in the OpenAPI spec
# TODO: Do we need to use create_openapi_http_exception_doc here?
responses={
status.HTTP_400_BAD_REQUEST: {"description": "Task Instance failed the runtime checks."},
status.HTTP_409_CONFLICT: {
"description": "Task Instance isn't in a running state. Cannot perform runtime checks."
},
status.HTTP_422_UNPROCESSABLE_ENTITY: {
"description": "Invalid payload for requested runtime checks on the Task Instance."
},
},
)
def ti_runtime_checks(
task_instance_id: UUID,
payload: TIRuntimeCheckPayload,
session: SessionDep,
):
ti_id_str = str(task_instance_id)
task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
if task_instance.state != State.RUNNING:
raise HTTPException(status_code=status.HTTP_409_CONFLICT)

try:
TI.validate_inlet_outlet_assets_activeness(payload.inlets, payload.outlets, session) # type: ignore
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
except AirflowInactiveAssetInInletOrOutletException as e:
log.error("Task Instance %s fails the runtime checks.", ti_id_str)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "validation_failed",
"message": "Task Instance fails the runtime checks",
"error": str(e),
},
)


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
Expand Down
21 changes: 14 additions & 7 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,10 @@ def _run_raw_task(
context = ti.get_template_context(ignore_param_exceptions=False, session=session)

try:
ti._validate_inlet_outlet_assets_activeness(session=session)
if ti.task:
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
if not mark_success:
TaskInstance._execute_task_with_callbacks(
self=ti, # type: ignore[arg-type]
Expand Down Expand Up @@ -3715,16 +3718,20 @@ def duration_expression_update(
}
)

def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None:
if not self.task or not (self.task.outlets or self.task.inlets):
@staticmethod
def validate_inlet_outlet_assets_activeness(
inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session
) -> None:
if not (inlets or outlets):
return

all_asset_unique_keys = {
AssetUniqueKey.from_asset(inlet_or_outlet)
for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets)
if isinstance(inlet_or_outlet, Asset)
AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore
for inlet_or_outlet in itertools.chain(inlets, outlets)
}
inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
all_asset_unique_keys, session
)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)

Expand Down
16 changes: 15 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
Expand All @@ -52,7 +53,7 @@
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, RuntimeCheckOnTask
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -177,6 +178,19 @@ def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunR
resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) -> OKResponse:
body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True))
try:
self.client.post(f"task-instances/{id}/runtime-checks", content=body.model_dump_json())
return OKResponse(ok=True)
except ServerResponseError as e:
if e.response.status_code == 400:
return OKResponse(ok=False)
elif e.response.status_code == 409:
# The TI isn't in the right state to perform the check, but we shouldn't fail the task for that
return OKResponse(ok=True)
raise


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ class TIRescheduleStatePayload(BaseModel):
end_date: Annotated[datetime, Field(title="End Date")]


class TIRuntimeCheckPayload(BaseModel):
"""
Payload for performing Runtime checks on the TaskInstance model as requested by the SDK.
"""

inlets: Annotated[list[AssetProfile] | None, Field(title="Inlets")] = None
outlets: Annotated[list[AssetProfile] | None, Field(title="Outlets")] = None


class TISuccessStatePayload(BaseModel):
"""
Schema for updating TaskInstance to success state.
Expand Down
5 changes: 5 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import attrs

from airflow.sdk.api.datamodels._generated import AssetProfile
from airflow.serialization.dag_dependency import DagDependency

if TYPE_CHECKING:
Expand Down Expand Up @@ -428,6 +429,10 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
dependency_id=self.name,
)

def asprofile(self) -> AssetProfile:
"""Profiles Asset to AssetProfile."""
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
return AssetProfile(name=self.name or None, uri=self.uri or None, asset_type=Asset.__name__)


class AssetRef(BaseAsset, AttrsInstance):
"""
Expand Down
12 changes: 12 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
TIDeferredStatePayload,
TIRescheduleStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TISuccessStatePayload,
VariableResponse,
XComResponse,
Expand Down Expand Up @@ -169,6 +170,11 @@ class ErrorResponse(BaseModel):
type: Literal["ErrorResponse"] = "ErrorResponse"


class OKResponse(BaseModel):
ok: bool
type: Literal["OKResponse"] = "OKResponse"


ToTask = Annotated[
Union[
AssetResult,
Expand All @@ -178,6 +184,7 @@ class ErrorResponse(BaseModel):
StartupDetails,
VariableResult,
XComResult,
OKResponse,
],
Field(discriminator="type"),
]
Expand Down Expand Up @@ -220,6 +227,10 @@ class RescheduleTask(TIRescheduleStatePayload):
type: Literal["RescheduleTask"] = "RescheduleTask"


class RuntimeCheckOnTask(TIRuntimeCheckPayload):
type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand Down Expand Up @@ -317,6 +328,7 @@ class GetPrevSuccessfulDagRun(BaseModel):
SetRenderedFields,
SetXCom,
TaskState,
RuntimeCheckOnTask,
],
Field(discriminator="type"),
]
4 changes: 4 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -767,6 +768,9 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, RuntimeCheckOnTask):
runtime_check_resp = self.client.task_instances.runtime_checks(id=self.id, msg=msg)
resp = runtime_check_resp.model_dump_json().encode()
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self.client.task_instances.succeed(
Expand Down
55 changes: 35 additions & 20 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
OKResponse,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -501,26 +503,39 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Get a real context object
ti.hostname = get_hostname()
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)

# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Update RTIF
# - Pre Execute
# etc
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
if ti.task.inlets or ti.task.outlets:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
inlets = [asset.asprofile() for asset in ti.task.inlets if isinstance(asset, Asset)]
outlets = [asset.asprofile() for asset in ti.task.outlets if isinstance(asset, Asset)]
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlets=inlets, outlets=outlets), log=log) # type: ignore
msg = SUPERVISOR_COMMS.get_message() # type: ignore

if isinstance(msg, OKResponse) and not msg.ok:
log.info("Runtime checks failed for task, marking task as failed..")
msg = TaskState(
state=TerminalTIState.FAILED,
end_date=datetime.now(tz=timezone.utc),
)
else:
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

task_outlets, outlet_events = _process_outlets(context, ti.task.outlets)

# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Update RTIF
# - Pre Execute
# etc
ashb marked this conversation as resolved.
Show resolved Hide resolved
msg = SucceedTask(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
)
except TaskDeferred as defer:
# TODO: Should we use structlog.bind_contextvars here for dag_id, task_id & run_id?
log.info("Pausing task as DEFERRED. ", dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id)
Expand Down
23 changes: 22 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from airflow.executors.workloads import BundleInfo
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState
from airflow.sdk.execution_time.comms import (
AssetResult,
ConnectionResult,
Expand All @@ -50,9 +50,11 @@
GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SucceedTask,
Expand Down Expand Up @@ -1011,6 +1013,25 @@ def watched_subprocess(self, mocker):
),
id="get_prev_successful_dagrun",
),
pytest.param(
RuntimeCheckOnTask(
inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
),
b'{"ok":true,"type":"OKResponse"}\n',
"task_instances.runtime_checks",
(),
{
"id": TI_ID,
"msg": RuntimeCheckOnTask(
inlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore
outlets=[AssetProfile(name="alias", uri="alias", asset_type="asset")], # type: ignore
type="RuntimeCheckOnTask",
),
},
OKResponse(ok=True),
id="runtime_check_on_task",
),
],
)
def test_handle_requests(
Expand Down
Loading