Skip to content

Commit

Permalink
Add 'airflow assets materialize' (again) (apache#44603)
Browse files Browse the repository at this point in the history
* Add 'airflow assets materialize' (apache#44558)

* Correctly handle session for object in exception
  • Loading branch information
uranusjr authored Dec 3, 2024
1 parent edf3e33 commit a537d9c
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 9 deletions.
17 changes: 12 additions & 5 deletions airflow/api/common/trigger_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from sqlalchemy.orm.session import Session


@provide_session
def _trigger_dag(
dag_id: str,
dag_bag: DagBag,
Expand All @@ -45,6 +46,7 @@ def _trigger_dag(
conf: dict | str | None = None,
logical_date: datetime | None = None,
replace_microseconds: bool = True,
session: Session = NEW_SESSION,
) -> DagRun | None:
"""
Triggers DAG run.
Expand All @@ -58,7 +60,7 @@ def _trigger_dag(
:param replace_microseconds: whether microseconds should be zeroed
:return: the triggered DAG run
"""
dag = dag_bag.get_dag(dag_id) # prefetch dag if it is stored serialized
dag = dag_bag.get_dag(dag_id, session=session) # prefetch dag if it is stored serialized

if dag is None or dag_id not in dag_bag.dags:
raise DagNotFound(f"Dag id {dag_id} not found")
Expand All @@ -84,15 +86,18 @@ def _trigger_dag(
run_id = run_id or dag.timetable.generate_run_id(
run_type=DagRunType.MANUAL, logical_date=coerced_logical_date, data_interval=data_interval
)
dag_run = DagRun.find_duplicate(dag_id=dag_id, run_id=run_id)

if dag_run:
# This intentionally does not use 'session' in the current scope because it
# may be rolled back when this function exits with an exception (due to how
# provide_session is implemented). This would leave the DagRun object carried
# by DagRunAlreadyExists expired and unusable.
if dag_run := DagRun.find_duplicate(dag_id=dag_id, run_id=run_id):
raise DagRunAlreadyExists(dag_run)

run_conf = None
if conf:
run_conf = conf if isinstance(conf, dict) else json.loads(conf)
dag_version = DagVersion.get_latest_version(dag.dag_id)
dag_version = DagVersion.get_latest_version(dag.dag_id, session=session)
dag_run = dag.create_dagrun(
run_id=run_id,
logical_date=logical_date,
Expand All @@ -102,6 +107,7 @@ def _trigger_dag(
dag_version=dag_version,
data_interval=data_interval,
triggered_by=triggered_by,
session=session,
)

return dag_run
Expand Down Expand Up @@ -130,7 +136,7 @@ def trigger_dag(
:param session: Database session used to look up the DagModel and passed on when triggering the DAG run
:return: the first DAG run triggered (even if more than one DAG run was triggered), or None
"""
dag_model = DagModel.get_current(dag_id)
dag_model = DagModel.get_current(dag_id, session=session)
if dag_model is None:
raise DagNotFound(f"Dag id {dag_id} not found in DagModel")

Expand All @@ -143,6 +149,7 @@ def trigger_dag(
logical_date=logical_date,
replace_microseconds=replace_microseconds,
triggered_by=triggered_by,
session=session,
)

return dr if dr else None
10 changes: 8 additions & 2 deletions airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,8 +935,8 @@ def string_lower_type(val):
default=("name", "uri", "group", "extra"),
)

ARG_ASSET_NAME = Arg(("--name",), help="Asset name")
ARG_ASSET_URI = Arg(("--uri",), help="Asset URI")
ARG_ASSET_NAME = Arg(("--name",), default="", help="Asset name")
ARG_ASSET_URI = Arg(("--uri",), default="", help="Asset URI")

ALTERNATIVE_CONN_SPECS_ARGS = [
ARG_CONN_TYPE,
Expand Down Expand Up @@ -986,6 +986,12 @@ class GroupCommand(NamedTuple):
func=lazy_load_command("airflow.cli.commands.asset_command.asset_details"),
args=(ARG_ASSET_NAME, ARG_ASSET_URI, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
name="materialize",
help="Materialize an asset",
func=lazy_load_command("airflow.cli.commands.asset_command.asset_materialize"),
args=(ARG_ASSET_NAME, ARG_ASSET_URI, ARG_OUTPUT, ARG_VERBOSE),
),
)
BACKFILL_COMMANDS = (
ActionCommand(
Expand Down
42 changes: 41 additions & 1 deletion airflow/cli/commands/asset_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@

from sqlalchemy import select

from airflow.api.common.trigger_dag import trigger_dag
from airflow.api_fastapi.core_api.datamodels.assets import AssetResponse
from airflow.api_fastapi.core_api.datamodels.dag_run import DAGRunResponse
from airflow.cli.simple_table import AirflowConsole
from airflow.models.asset import AssetModel
from airflow.models.asset import AssetModel, TaskOutletAssetReference
from airflow.utils import cli as cli_utils
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.types import DagRunTriggeredByType

if typing.TYPE_CHECKING:
from typing import Any
Expand Down Expand Up @@ -83,3 +86,40 @@ def asset_details(args, *, session: Session = NEW_SESSION) -> None:
data = [model_data]

AirflowConsole().print_as(data=data, output=args.output)


@cli_utils.action_cli
@provide_session
def asset_materialize(args, *, session: Session = NEW_SESSION) -> None:
    """
    Trigger a run of the single DAG that produces the requested asset.

    The asset is selected by ``args.name`` and/or ``args.uri``. DAG IDs are
    looked up through the task outlet references joined to the asset; exactly
    one DAG must match. The resulting run (if any) is printed in the format
    given by ``args.output``.

    :param args: parsed CLI arguments; reads ``name``, ``uri`` and ``output``
    :param session: database session used for the lookup and for triggering
    :raises SystemExit: if neither name nor URI is given, if no DAG is found
        for the asset, or if more than one DAG is found
    """
    if not (args.name or args.uri):
        raise SystemExit("Either --name or --uri is required")

    query = select(TaskOutletAssetReference.dag_id).join(TaskOutletAssetReference.asset)
    descriptions = []
    # Each entry: label used in messages, column to filter on, value supplied by the user.
    filters = (
        ("name", AssetModel.name, args.name),
        ("URI", AssetModel.uri, args.uri),
    )
    for label, column, value in filters:
        if value:
            query = query.where(column == value)
            descriptions.append(f"{label} {value}")
    query = query.group_by(TaskOutletAssetReference.dag_id).limit(2)

    # Two rows at most are requested: enough to tell "none", "one" and "many" apart.
    found = iter(session.scalars(query))
    select_message = " and ".join(descriptions)

    dag_id = next(found, None)
    if dag_id is None:
        raise SystemExit(f"Asset with {select_message} does not exist.")
    if next(found, None) is not None:
        raise SystemExit(f"More than one DAG materializes asset with {select_message}.")

    dagrun = trigger_dag(dag_id=dag_id, triggered_by=DagRunTriggeredByType.CLI, session=session)
    data = [] if dagrun is None else [DAGRunResponse.model_validate(dagrun).model_dump()]

    AirflowConsole().print_as(data=data, output=args.output)
39 changes: 38 additions & 1 deletion tests/cli/commands/test_asset_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from airflow.models.dagbag import DagBag

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_dags
from tests_common.test_utils.db import clear_db_dags, clear_db_runs

if typing.TYPE_CHECKING:
from argparse import ArgumentParser
Expand All @@ -42,9 +42,15 @@
def prepare_examples():
DagBag(include_examples=True).sync_to_db()
yield
clear_db_runs()
clear_db_dags()


@pytest.fixture(autouse=True)
def clear_runs():
    """Clear DAG runs from the database before each test in this module."""
    clear_db_runs()


@pytest.fixture(scope="module")
def parser() -> ArgumentParser:
    """Provide the CLI argument parser, built once per test module."""
    cli_argument_parser = cli_parser.get_parser()
    return cli_argument_parser
Expand Down Expand Up @@ -89,3 +95,34 @@ def test_cli_assets_details(parser: ArgumentParser) -> None:
"extra": {},
"aliases": [],
}


def test_cli_assets_materialize(parser: ArgumentParser) -> None:
    """Check that 'assets materialize' prints exactly one run for the producing DAG."""
    cli_args = parser.parse_args(["assets", "materialize", "--name=asset1_producer", "--output=json"])
    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        asset_command.asset_materialize(cli_args)

    run_list = json.loads(captured.getvalue())
    assert len(run_list) == 1

    # These fields vary from run to run, so they are blanked out on both
    # sides of the comparison instead of being checked against fixed values.
    volatile = dict.fromkeys(
        (
            "dag_run_id",
            "data_interval_end",
            "data_interval_start",
            "logical_date",
            "queued_at",
        )
    )

    stable_expected = {
        "conf": {},
        "dag_id": "asset1_producer",
        "end_date": None,
        "external_trigger": "True",
        "last_scheduling_decision": None,
        "note": None,
        "run_type": "manual",
        "start_date": None,
        "state": "queued",
        "triggered_by": "DagRunTriggeredByType.CLI",
    }

    actual = run_list[0] | volatile
    expected = volatile | stable_expected
    assert actual == expected

0 comments on commit a537d9c

Please sign in to comment.