Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK #46020

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
14c4948
AIP-72: Port _validate_inlet_outlet_assets_activeness into Task SDK
amoghrajesh Jan 24, 2025
2aac70f
fixing tests
amoghrajesh Jan 26, 2025
ba8029e
adding tests for supervisor
amoghrajesh Jan 27, 2025
3781f5c
review comments ash
amoghrajesh Jan 27, 2025
6187c4c
fixing tests
amoghrajesh Jan 27, 2025
def0df8
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 27, 2025
9c7d7bc
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 28, 2025
835a04c
review part 1: changing the execution api
amoghrajesh Jan 28, 2025
b621a86
review part 2: changing comms
amoghrajesh Jan 28, 2025
5a6fe55
review comments on the task runner side
amoghrajesh Jan 28, 2025
bfd88b8
introducing asprofile on asste
amoghrajesh Jan 28, 2025
df21777
fixing tests
amoghrajesh Jan 28, 2025
9f262ff
fixing tests in CI
amoghrajesh Jan 28, 2025
e370faf
review comments ash
amoghrajesh Jan 28, 2025
a4cc173
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 28, 2025
c5c77fe
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
9ce6a13
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
3664b05
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
ab6cd46
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
618ace5
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 29, 2025
d1b828f
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 30, 2025
9559d6f
Merge branch 'main' into AIP72_validate_inlet_outlet_assets_activenes…
amoghrajesh Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,10 @@ class PrevSuccessfulDagRunResponse(BaseModel):
data_interval_end: UtcDateTime | None = None
start_date: UtcDateTime | None = None
end_date: UtcDateTime | None = None


class TIRuntimeCheckPayload(BaseModel):
"""Payload for performing Runtime checks on the TaskInstance model as requested by the SDK."""

inlet: list[AssetProfile] | None = None
outlet: list[AssetProfile] | None = None
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
46 changes: 45 additions & 1 deletion airflow/api_fastapi/execution_api/routes/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Annotated
from uuid import UUID

from fastapi import Body, HTTPException, status
from fastapi import Body, HTTPException, Response, status
from pydantic import JsonValue
from sqlalchemy import update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
Expand All @@ -37,10 +37,12 @@
TIHeartbeatInfo,
TIRescheduleStatePayload,
TIRunContext,
TIRuntimeCheckPayload,
TIStateUpdate,
TISuccessStatePayload,
TITerminalStatePayload,
)
from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _update_rtif
from airflow.models.taskreschedule import TaskReschedule
Expand Down Expand Up @@ -442,6 +444,48 @@ def get_previous_successful_dagrun(
return PrevSuccessfulDagRunResponse.model_validate(dag_run)


@router.post(
"/{task_instance_id}/runtime-checks",
status_code=status.HTTP_200_OK,
# TODO: Add description to the operation
# TODO: Add Operation ID to control the function name in the OpenAPI spec
# TODO: Do we need to use create_openapi_http_exception_doc here?
responses={
status.HTTP_400_BAD_REQUEST: {
"description": "Task Instance doesn't pass the required runtime checks"
},
status.HTTP_204_NO_CONTENT: {
"description": "Task Instance is not in a running state, cannot perform runtime checks."
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
},
status.HTTP_422_UNPROCESSABLE_ENTITY: {
"description": "Invalid payload for requested runtime checks on the Task Instance."
},
},
)
def ti_runtime_checks(
task_instance_id: UUID,
payload: TIRuntimeCheckPayload,
session: SessionDep,
):
ti_id_str = str(task_instance_id)
task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
if task_instance.state != State.RUNNING:
return Response(status_code=status.HTTP_204_NO_CONTENT)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
try:
TI.validate_inlet_outlet_assets_activeness(payload.inlet, payload.outlet, session) # type: ignore
except AirflowInactiveAssetInInletOrOutletException as e:
log.error("Task Instance %s doesn't pass the runtime checks.", ti_id_str)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "validation_failed",
"message": "Task Instance fails the runtime checks",
"error": str(e),
},
)
return {"message": "Runtime checks passed successfully."}


def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool:
"""Is task instance is eligible for retry."""
if state == State.RESTARTING:
Expand Down
29 changes: 22 additions & 7 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,18 @@ def _run_raw_task(
context = ti.get_template_context(ignore_param_exceptions=False, session=session)

try:
ti._validate_inlet_outlet_assets_activeness(session=session)
if ti.task:
inlets = [
AssetProfile(name=x.name or None, uri=x.uri or None, asset_type=type(x).__name__)
for x in ti.task.inlets
if isinstance(x, Asset)
]
outlets = [
AssetProfile(name=x.name or None, uri=x.uri or None, asset_type=type(x).__name__)
for x in ti.task.outlets
if isinstance(x, Asset)
]
TaskInstance.validate_inlet_outlet_assets_activeness(inlets, outlets, session=session)
if not mark_success:
TaskInstance._execute_task_with_callbacks(
self=ti, # type: ignore[arg-type]
Expand Down Expand Up @@ -3715,16 +3726,20 @@ def duration_expression_update(
}
)

def _validate_inlet_outlet_assets_activeness(self, session: Session) -> None:
if not self.task or not (self.task.outlets or self.task.inlets):
@staticmethod
def validate_inlet_outlet_assets_activeness(
inlets: list[AssetProfile], outlets: list[AssetProfile], session: Session
) -> None:
if not (inlets or outlets):
return

all_asset_unique_keys = {
AssetUniqueKey.from_asset(inlet_or_outlet)
for inlet_or_outlet in itertools.chain(self.task.inlets, self.task.outlets)
if isinstance(inlet_or_outlet, Asset)
AssetUniqueKey.from_asset(inlet_or_outlet) # type: ignore
for inlet_or_outlet in itertools.chain(inlets, outlets)
}
inactive_asset_unique_keys = self._get_inactive_asset_unique_keys(all_asset_unique_keys, session)
inactive_asset_unique_keys = TaskInstance._get_inactive_asset_unique_keys(
all_asset_unique_keys, session
)
if inactive_asset_unique_keys:
raise AirflowInactiveAssetInInletOrOutletException(inactive_asset_unique_keys)

Expand Down
13 changes: 12 additions & 1 deletion task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
from airflow.sdk import __version__
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
Expand All @@ -52,7 +53,7 @@
XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.sdk.execution_time.comms import ErrorResponse, OKResponse, RuntimeCheckOnTask
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser

Expand Down Expand Up @@ -177,6 +178,16 @@ def get_previous_successful_dagrun(self, id: uuid.UUID) -> PrevSuccessfulDagRunR
resp = self.client.get(f"task-instances/{id}/previous-successful-dagrun")
return PrevSuccessfulDagRunResponse.model_validate_json(resp.read())

def runtime_checks(self, id: uuid.UUID, msg: RuntimeCheckOnTask) -> OKResponse:
body = TIRuntimeCheckPayload(**msg.model_dump(exclude_unset=True))
try:
self.client.post(f"task-instances/{id}/runtime-checks", content=body.model_dump_json())
return OKResponse(ok=True)
except ServerResponseError as e:
if e.response.status_code == 400:
return OKResponse(ok=False)
raise


class ConnectionOperations:
__slots__ = ("client",)
Expand Down
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ class TIRescheduleStatePayload(BaseModel):
end_date: Annotated[datetime, Field(title="End Date")]


class TIRuntimeCheckPayload(BaseModel):
"""
Payload for performing Runtime checks on the TaskInstance model as requested by the SDK.
"""

inlet: Annotated[list[AssetProfile] | None, Field(title="Inlet")] = None
outlet: Annotated[list[AssetProfile] | None, Field(title="Outlet")] = None


class TISuccessStatePayload(BaseModel):
"""
Schema for updating TaskInstance to success state.
Expand Down
12 changes: 12 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from fastapi import Body
from pydantic import BaseModel, ConfigDict, Field, JsonValue

from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIRuntimeCheckPayload
from airflow.sdk.api.datamodels._generated import (
AssetResponse,
BundleInfo,
Expand Down Expand Up @@ -169,6 +170,11 @@ class ErrorResponse(BaseModel):
type: Literal["ErrorResponse"] = "ErrorResponse"


class OKResponse(BaseModel):
ok: bool
type: Literal["OKResponse"] = "OKResponse"


ToTask = Annotated[
Union[
AssetResult,
Expand All @@ -178,6 +184,7 @@ class ErrorResponse(BaseModel):
StartupDetails,
VariableResult,
XComResult,
OKResponse,
],
Field(discriminator="type"),
]
Expand Down Expand Up @@ -220,6 +227,10 @@ class RescheduleTask(TIRescheduleStatePayload):
type: Literal["RescheduleTask"] = "RescheduleTask"


class RuntimeCheckOnTask(TIRuntimeCheckPayload):
type: Literal["RuntimeCheckOnTask"] = "RuntimeCheckOnTask"


class GetXCom(BaseModel):
key: str
dag_id: str
Expand Down Expand Up @@ -317,6 +328,7 @@ class GetPrevSuccessfulDagRun(BaseModel):
SetRenderedFields,
SetXCom,
TaskState,
RuntimeCheckOnTask,
],
Field(discriminator="type"),
]
9 changes: 9 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -767,6 +768,14 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
elif isinstance(msg, RuntimeCheckOnTask):
runtime_check_resp = self.client.task_instances.runtime_checks(id=self.id, msg=msg)
resp = runtime_check_resp.model_dump_json().encode()
if not runtime_check_resp.ok:
log.debug("Runtime checks failed on task %s, marking task as failed..", self.id)
self.client.task_instances.finish(
id=self.id, state=TerminalTIState.FAILED, when=datetime.now(tz=timezone.utc)
)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self.client.task_instances.succeed(
Expand Down
18 changes: 18 additions & 0 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
from airflow.sdk.execution_time.comms import (
DeferTask,
GetXCom,
OKResponse,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
StartupDetails,
Expand Down Expand Up @@ -501,6 +503,22 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Get a real context object
ti.hostname = get_hostname()
ti.task = ti.task.prepare_for_execution()
if ti.task.inlets or ti.task.outlets:
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
inlets = [
AssetProfile(name=x.name or None, uri=x.uri or None, asset_type=Asset.__name__)
for x in ti.task.inlets
if isinstance(x, Asset)
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
]
outlets = [
AssetProfile(name=x.name or None, uri=x.uri or None, asset_type=Asset.__name__)
for x in ti.task.outlets
if isinstance(x, Asset)
]
SUPERVISOR_COMMS.send_request(msg=RuntimeCheckOnTask(inlet=inlets, outlet=outlets), log=log) # type: ignore
msg = SUPERVISOR_COMMS.get_message() # type: ignore
if isinstance(msg, OKResponse) and not msg.ok:
log.info("Runtime checks failed for task, marking task as failed..")
return
amoghrajesh marked this conversation as resolved.
Show resolved Hide resolved
context = ti.get_template_context()
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
Expand Down
23 changes: 22 additions & 1 deletion task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
from airflow.executors.workloads import BundleInfo
from airflow.sdk.api import client as sdk_client
from airflow.sdk.api.client import ServerResponseError
from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState
from airflow.sdk.api.datamodels._generated import AssetProfile, TaskInstance, TerminalTIState
from airflow.sdk.definitions.asset import Asset
from airflow.sdk.execution_time.comms import (
AssetResult,
ConnectionResult,
Expand All @@ -50,9 +51,11 @@
GetPrevSuccessfulDagRun,
GetVariable,
GetXCom,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
RuntimeCheckOnTask,
SetRenderedFields,
SetXCom,
SucceedTask,
Expand Down Expand Up @@ -1011,6 +1014,24 @@ def watched_subprocess(self, mocker):
),
id="get_prev_successful_dagrun",
),
pytest.param(
RuntimeCheckOnTask(
inlet=[Asset(name="alias", uri="alias")], outlet=[Asset(name="alias", uri="alias")]
),
b'{"ok":true,"type":"OKResponse"}\n',
"task_instances.runtime_checks",
(),
{
"id": TI_ID,
"msg": RuntimeCheckOnTask(
inlet=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
outlet=[AssetProfile(name="alias", uri="alias", asset_type="asset")],
type="RuntimeCheckOnTask",
),
},
OKResponse(ok=True),
id="runtime_check_on_task",
),
],
)
def test_handle_requests(
Expand Down
39 changes: 39 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@
GetConnection,
GetVariable,
GetXCom,
OKResponse,
PrevSuccessfulDagRunResult,
RuntimeCheckOnTask,
SetRenderedFields,
StartupDetails,
SucceedTask,
Expand Down Expand Up @@ -651,6 +653,43 @@ def test_run_with_asset_outlets(
mock_supervisor_comms.send_request.assert_any_call(msg=expected_msg, log=mock.ANY)


def test_run_with_inlets_and_outlets(create_runtime_ti, mock_supervisor_comms):
"""Test running a basic tasks with inlets and outlets."""
from airflow.providers.standard.operators.bash import BashOperator

task = BashOperator(
outlets=[
Asset(name="name", uri="s3://bucket/my-task"),
Asset(name="new-name", uri="s3://bucket/my-task"),
],
inlets=[
Asset(name="name", uri="s3://bucket/my-task"),
Asset(name="new-name", uri="s3://bucket/my-task"),
],
task_id="inlets-and-outlets",
bash_command="echo 'hi'",
)

ti = create_runtime_ti(task=task, dag_id="dag_with_inlets_and_outlets")
mock_supervisor_comms.get_message.return_value = OKResponse(
ok=False,
)

run(ti, log=mock.MagicMock())

expected = RuntimeCheckOnTask(
inlet=[
AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"),
AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"),
],
outlet=[
AssetProfile(name="name", uri="s3://bucket/my-task", asset_type="Asset"),
AssetProfile(name="new-name", uri="s3://bucket/my-task", asset_type="Asset"),
],
)
mock_supervisor_comms.send_request.assert_called_with(msg=expected, log=mock.ANY)


class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_context):
"""Test get_template_context without ti_context_from_server."""
Expand Down
Loading
Loading