Refactor dataset class inheritance (apache#37590)
* Refactor DatasetAll and DatasetAny inheritance

They are moved from airflow.models.dataset to airflow.datasets since
the intention is to use them with Dataset, not DatasetModel. It is more
natural for users to import from the latter module instead.

A new (abstract) base class is added for the two classes, plus the
original Dataset class, to inherit from. This allows us to replace a few
isinstance checks with simple polymorphism and make the logic a bit
simpler.

---------

Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Wei Lee <[email protected]>
3 people authored Feb 22, 2024
1 parent 810fb5f commit 0900055
Showing 6 changed files with 82 additions and 77 deletions.
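
As a quick illustration of the new import path described in the commit message, here is a minimal DAG sketch (hypothetical dag_id and URIs; it assumes Airflow at this commit, where DAG(schedule=...) accepts these condition objects directly):

    import pendulum

    from airflow import DAG
    from airflow.datasets import Dataset, DatasetAll, DatasetAny

    # DatasetAll and DatasetAny are now importable alongside Dataset itself,
    # rather than from airflow.models.dataset.
    with DAG(
        dag_id="example_dataset_condition",  # hypothetical
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=DatasetAny(
            Dataset("s3://bucket/a"),
            DatasetAll(Dataset("s3://bucket/b"), Dataset("s3://bucket/c")),
        ),
    ):
        ...
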
64 changes: 60 additions & 4 deletions airflow/datasets/__init__.py
@@ -14,18 +14,35 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import os
from typing import Any, ClassVar
from typing import Any, Callable, ClassVar, Iterable, Iterator, Protocol, runtime_checkable
from urllib.parse import urlsplit

import attr

__all__ = ["Dataset", "DatasetAll", "DatasetAny"]


@runtime_checkable
class BaseDatasetEventInput(Protocol):
"""Protocol for all dataset triggers to use in ``DAG(schedule=...)``.
:meta private:
"""

def evaluate(self, statuses: dict[str, bool]) -> bool:
raise NotImplementedError

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
raise NotImplementedError


@attr.define()
class Dataset(os.PathLike):
"""A Dataset is used for marking data dependencies between workflows."""
class Dataset(os.PathLike, BaseDatasetEventInput):
"""A representation of data dependencies between workflows."""

uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
extra: dict[str, Any] | None = None
@@ -44,7 +61,7 @@ def _check_uri(self, attr, uri: str):
if parsed.scheme and parsed.scheme.lower() == "airflow":
raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")

def __fspath__(self):
def __fspath__(self) -> str:
return self.uri

def __eq__(self, other):
@@ -55,3 +72,42 @@ def __eq__(self, other):

def __hash__(self):
return hash(self.uri)

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
yield self.uri, self

def evaluate(self, statuses: dict[str, bool]) -> bool:
return statuses.get(self.uri, False)


class _DatasetBooleanCondition(BaseDatasetEventInput):
"""Base class for dataset boolean logic."""

agg_func: Callable[[Iterable], bool]

def __init__(self, *objects: BaseDatasetEventInput) -> None:
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(x.evaluate(statuses=statuses) for x in self.objects)

def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
seen = set() # We want to keep the first instance.
for o in self.objects:
for k, v in o.iter_datasets():
if k in seen:
continue
yield k, v
seen.add(k)


class DatasetAny(_DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "and" relationship."""

agg_func = any


class DatasetAll(_DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "or" relationship."""

agg_func = all
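
A small sketch of how the polymorphic API above behaves (made-up URIs; statuses maps a dataset URI to whether its event has been received):

    from airflow.datasets import Dataset, DatasetAll, DatasetAny

    cond = DatasetAny(
        Dataset("s3://bucket/a"),
        DatasetAll(Dataset("s3://bucket/b"), Dataset("s3://bucket/c")),
    )

    # Every node implements evaluate(), so no isinstance dispatch is needed.
    statuses = {"s3://bucket/a": False, "s3://bucket/b": True, "s3://bucket/c": True}
    assert cond.evaluate(statuses) is True  # the nested DatasetAll is satisfied

    # iter_datasets() flattens the tree, yielding each URI once (first one wins).
    assert [uri for uri, _ in cond.iter_datasets()] == [
        "s3://bucket/a",
        "s3://bucket/b",
        "s3://bucket/c",
    ]
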
26 changes: 9 additions & 17 deletions airflow/models/dag.py
@@ -80,6 +80,7 @@
from airflow import settings, utils
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf as airflow_conf, secrets_backend_list
from airflow.datasets import BaseDatasetEventInput, Dataset, DatasetAll
from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowDagInconsistent,
@@ -98,13 +99,7 @@
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import RUN_ID_REGEX, DagRun
from airflow.models.dataset import (
DatasetAll,
DatasetAny,
DatasetBooleanCondition,
DatasetDagRunQueue,
DatasetModel,
)
from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import (
Context,
@@ -150,7 +145,6 @@
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session

from airflow.datasets import Dataset
from airflow.decorators import TaskDecoratorCollection
from airflow.models.dagbag import DagBag
from airflow.models.operator import Operator
@@ -174,7 +168,7 @@
# but Mypy cannot handle that right now. Track PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]]
ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, BaseDatasetEventInput, Collection["Dataset"]]

SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]

@@ -586,12 +580,10 @@ def __init__(

self.timetable: Timetable
self.schedule_interval: ScheduleInterval
self.dataset_triggers: DatasetBooleanCondition | None = None
if isinstance(schedule, (DatasetAll, DatasetAny)):
self.dataset_triggers: BaseDatasetEventInput | None = None
if isinstance(schedule, BaseDatasetEventInput):
self.dataset_triggers = schedule
if isinstance(schedule, Collection) and not isinstance(schedule, str):
from airflow.datasets import Dataset

elif isinstance(schedule, Collection) and not isinstance(schedule, str):
if not all(isinstance(x, Dataset) for x in schedule):
raise ValueError("All elements in 'schedule' should be datasets")
self.dataset_triggers = DatasetAll(*schedule)
@@ -3181,7 +3173,7 @@ def bulk_write_to_db(
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
else:
for dataset in dag.dataset_triggers.all_datasets().values():
for _, dataset in dag.dataset_triggers.iter_datasets():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3793,14 +3785,14 @@ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[
"""
from airflow.models.serialized_dag import SerializedDagModel

def dag_ready(dag_id: str, cond: DatasetBooleanCondition, statuses: dict) -> bool | None:
def dag_ready(dag_id: str, cond: BaseDatasetEventInput, statuses: dict) -> bool | None:
# if dag was serialized before 2.9 and we *just* upgraded,
# we may be dealing with an old version. In that case,
# just wait for the dag to be reserialized.
try:
return cond.evaluate(statuses)
except AttributeError:
log.warning("dag '%s' has old serialization; skipping dag run creation.", dag_id)
log.warning("dag '%s' has old serialization; skipping DAG run creation.", dag_id)
return None

# this loads all the DDRQ records.... may need to limit num dags
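
The constructor logic in the hunk above reduces to a single normalization step. A sketch of the equivalent dispatch — _normalize_schedule is a hypothetical name used for illustration, not a helper in this commit:

    from __future__ import annotations

    from collections.abc import Collection

    from airflow.datasets import BaseDatasetEventInput, Dataset, DatasetAll

    def _normalize_schedule(schedule) -> BaseDatasetEventInput | None:
        # Dataset, DatasetAll, and DatasetAny all satisfy the runtime-checkable
        # protocol, so one check replaces the old isinstance pair.
        if isinstance(schedule, BaseDatasetEventInput):
            return schedule
        if isinstance(schedule, Collection) and not isinstance(schedule, str):
            if not all(isinstance(x, Dataset) for x in schedule):
                raise ValueError("All elements in 'schedule' should be datasets")
            return DatasetAll(*schedule)
        return None  # time-based schedules are handled elsewhere
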
47 changes: 0 additions & 47 deletions airflow/models/dataset.py
@@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

from typing import Callable, Iterable
from urllib.parse import urlsplit

import sqlalchemy_jsonfield
@@ -337,49 +336,3 @@ def __repr__(self) -> str:
]:
args.append(f"{attr}={getattr(self, attr)!r}")
return f"{self.__class__.__name__}({', '.join(args)})"


class DatasetBooleanCondition:
"""
Base class for boolean logic for dataset triggers.
:meta private:
"""

agg_func: Callable[[Iterable], bool]

def __init__(self, *objects) -> None:
self.objects = objects

def evaluate(self, statuses: dict[str, bool]) -> bool:
return self.agg_func(self.eval_one(x, statuses) for x in self.objects)

def eval_one(self, obj: Dataset | DatasetAny | DatasetAll, statuses) -> bool:
if isinstance(obj, Dataset):
return statuses.get(obj.uri, False)
return obj.evaluate(statuses=statuses)

def all_datasets(self) -> dict[str, Dataset]:
uris = {}
for x in self.objects:
if isinstance(x, Dataset):
if x.uri not in uris:
uris[x.uri] = x
else:
# keep the first instance
for k, v in x.all_datasets().items():
if k not in uris:
uris[k] = v
return uris


class DatasetAny(DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "and" relationship."""

agg_func = any


class DatasetAll(DatasetBooleanCondition):
"""Use to combine datasets schedule references in an "or" relationship."""

agg_func = all
5 changes: 2 additions & 3 deletions airflow/serialization/serialized_objects.py
@@ -35,14 +35,13 @@

from airflow.compat.functools import cache
from airflow.configuration import conf
from airflow.datasets import Dataset
from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError
from airflow.jobs.job import Job
from airflow.models.baseoperator import BaseOperator
from airflow.models.connection import Connection
from airflow.models.dag import DAG, DagModel, create_timetable
from airflow.models.dagrun import DagRun
from airflow.models.dataset import DatasetAll, DatasetAny
from airflow.models.expandinput import EXPAND_INPUT_EMPTY, create_expand_input, get_map_type_key
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
@@ -788,7 +787,7 @@ def detect_dag_dependencies(dag: DAG | None) -> Iterable[DagDependency]:
return
if not dag.dataset_triggers:
return
for uri in dag.dataset_triggers.all_datasets().keys():
for uri, _ in dag.dataset_triggers.iter_datasets():
yield DagDependency(
source="dataset",
target=dag.dag_id,
13 changes: 9 additions & 4 deletions airflow/timetables/datasets.py
@@ -19,8 +19,8 @@

import typing

from airflow.datasets import BaseDatasetEventInput, DatasetAll
from airflow.exceptions import AirflowTimetableInvalid
from airflow.models.dataset import DatasetAll, DatasetBooleanCondition
from airflow.timetables.simple import DatasetTriggeredTimetable as DatasetTriggeredSchedule
from airflow.utils.types import DagRunType

@@ -36,9 +36,14 @@
class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
"""Combine time-based scheduling with event-based scheduling."""

def __init__(self, timetable: Timetable, datasets: Collection[Dataset] | DatasetBooleanCondition) -> None:
def __init__(
self,
*,
timetable: Timetable,
datasets: Collection[Dataset] | BaseDatasetEventInput,
) -> None:
self.timetable = timetable
if isinstance(datasets, DatasetBooleanCondition):
if isinstance(datasets, BaseDatasetEventInput):
self.datasets = datasets
else:
self.datasets = DatasetAll(*datasets)
@@ -70,7 +75,7 @@ def serialize(self) -> dict[str, typing.Any]:
def validate(self) -> None:
if isinstance(self.timetable, DatasetTriggeredSchedule):
raise AirflowTimetableInvalid("cannot nest dataset timetables")
if not isinstance(self.datasets, DatasetBooleanCondition):
if not isinstance(self.datasets, BaseDatasetEventInput):
raise AirflowTimetableInvalid("all elements in 'datasets' must be datasets")

@property
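
A usage sketch for the updated keyword-only signature (CronTriggerTimetable stands in for any Timetable; the URIs are made up):

    from airflow.datasets import Dataset, DatasetAny
    from airflow.timetables.datasets import DatasetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    # `datasets` may now be any BaseDatasetEventInput (e.g. a DatasetAny tree),
    # not only a plain collection of Dataset objects.
    schedule = DatasetOrTimeSchedule(
        timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
        datasets=DatasetAny(Dataset("s3://bucket/a"), Dataset("s3://bucket/b")),
    )
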
4 changes: 2 additions & 2 deletions tests/datasets/test_dataset.py
@@ -23,8 +23,8 @@
import pytest
from sqlalchemy.sql import select

from airflow.datasets import Dataset
from airflow.models.dataset import DatasetAll, DatasetAny, DatasetDagRunQueue, DatasetModel
from airflow.datasets import Dataset, DatasetAll, DatasetAny
from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
