From 1f0f9079c74c4535bdd5973724e6f7706296cffc Mon Sep 17 00:00:00 2001
From: Daniel Standish <15932138+dstandish@users.noreply.github.com>
Date: Mon, 15 Apr 2024 02:08:47 -0700
Subject: [PATCH] Ensure we can serialize DatasetEventAccessor(s) properly
 (#38993)

---
 airflow/serialization/enums.py                |  2 ++
 airflow/serialization/serialized_objects.py   | 20 +++++++++++++++++--
 airflow/utils/context.pyi                     |  1 +
 docs/spelling_wordlist.txt                    |  1 +
 .../serialization/test_serialized_objects.py  | 12 +++++++++++
 5 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/airflow/serialization/enums.py b/airflow/serialization/enums.py
index 9b7cdbcc738ad..754412c653c1c 100644
--- a/airflow/serialization/enums.py
+++ b/airflow/serialization/enums.py
@@ -37,6 +37,8 @@ class DagAttributeTypes(str, Enum):
     """Enum of supported attribute types of DAG."""
 
     DAG = "dag"
+    DATASET_EVENT_ACCESSORS = "dataset_event_accessors"
+    DATASET_EVENT_ACCESSOR = "dataset_event_accessor"
     OP = "operator"
     DATETIME = "datetime"
     TIMEDELTA = "timedelta"
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 8dd7465c5f63c..9f8664477b029 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -67,7 +67,7 @@
     airflow_priority_weight_strategies_classes,
 )
 from airflow.utils.code_utils import get_python_source
-from airflow.utils.context import Context
+from airflow.utils.context import Context, DatasetEventAccessor, DatasetEventAccessors
 from airflow.utils.docs import get_docs_url
 from airflow.utils.module_loading import import_string, qualname
 from airflow.utils.operator_resources import Resources
@@ -534,6 +534,16 @@ def serialize(
         elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
             json_pod = PodGenerator.serialize_pod(var)
             return cls._encode(json_pod, type_=DAT.POD)
+        elif isinstance(var, DatasetEventAccessors):
+            return cls._encode(
+                cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models),  # type: ignore[attr-defined]
+                type_=DAT.DATASET_EVENT_ACCESSORS,
+            )
+        elif isinstance(var, DatasetEventAccessor):
+            return cls._encode(
+                cls.serialize(var.extra, strict=strict, use_pydantic_models=use_pydantic_models),
+                type_=DAT.DATASET_EVENT_ACCESSOR,
+            )
         elif isinstance(var, DAG):
             return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
         elif isinstance(var, Resources):
@@ -663,8 +673,14 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
                 d[k] = cls.deserialize(v, use_pydantic_models=True)
             d["task"] = d["task_instance"].task  # todo: add `_encode` of Operator so we don't need this
             return Context(**d)
-        if type_ == DAT.DICT:
+        elif type_ == DAT.DICT:
             return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
+        elif type_ == DAT.DATASET_EVENT_ACCESSORS:
+            d = DatasetEventAccessors()  # type: ignore[assignment]
+            d._dict = cls.deserialize(var)  # type: ignore[attr-defined]
+            return d
+        elif type_ == DAT.DATASET_EVENT_ACCESSOR:
+            return DatasetEventAccessor(extra=cls.deserialize(var))
         elif type_ == DAT.DAG:
             return SerializedDAG.deserialize_dag(var)
         elif type_ == DAT.OP:
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index eb2cf6dd3e46f..24d31ccfc9f33 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -57,6 +57,7 @@ class ConnectionAccessor:
     def get(self, key: str, default_conn: Any = None) -> Any: ...
 
 class DatasetEventAccessor:
+    def __init__(self, *, extra: dict[str, Any]) -> None: ...
     extra: dict[str, Any]
 
 class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index dcd8641d8061b..eb7dcb7f47225 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1429,6 +1429,7 @@ segmentGranularity
 Sendgrid
 sendgrid
 sentimentMax
+ser
 serde
 serialise
 serializable
diff --git a/tests/serialization/test_serialized_objects.py b/tests/serialization/test_serialized_objects.py
index bfeff47627741..88f9165c66510 100644
--- a/tests/serialization/test_serialized_objects.py
+++ b/tests/serialization/test_serialized_objects.py
@@ -51,6 +51,7 @@
 from airflow.serialization.serialized_objects import BaseSerialization
 from airflow.settings import _ENABLE_AIP_44
 from airflow.utils import timezone
+from airflow.utils.context import DatasetEventAccessors
 from airflow.utils.operator_resources import Resources
 from airflow.utils.pydantic import BaseModel
 from airflow.utils.state import DagRunState, State
@@ -410,3 +411,14 @@ def test_serialized_mapped_operator_unmap(dag_maker):
     serialized_unmapped_task = serialized_task2.unmap(None)
 
     assert serialized_unmapped_task.dag is serialized_dag
+
+
+def test_ser_of_dataset_event_accessor():
+    # todo: (Airflow 3.0) we should force reserialization on upgrade
+    d = DatasetEventAccessors()
+    d["hi"].extra = "blah1"  # todo: this should maybe be forbidden? i.e. can extra be any json or just dict?
+    d["yo"].extra = {"this": "that", "the": "other"}
+    ser = BaseSerialization.serialize(var=d)
+    deser = BaseSerialization.deserialize(ser)
+    assert deser["hi"].extra == "blah1"
+    assert deser["yo"].extra == {"this": "that", "the": "other"}
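
---
Note: a minimal round-trip sketch of what this patch enables, mirroring the
new test above. It assumes a checkout that includes this change; the envelope
keys shown in the comment are an approximation of BaseSerialization's tagged
encoding, with the type tag coming from the enum values added here.

    from airflow.serialization.serialized_objects import BaseSerialization
    from airflow.utils.context import DatasetEventAccessors

    accessors = DatasetEventAccessors()
    # Indexing auto-creates a DatasetEventAccessor (with extra={}) for a new key.
    accessors["hi"].extra = {"this": "that"}

    # serialize() wraps the payload with the new type tag, roughly:
    #   {"__type": "dataset_event_accessors", "__var": {"hi": {...}}}
    encoded = BaseSerialization.serialize(var=accessors)

    # deserialize() rebuilds the accessors, so values of this type (e.g. the
    # outlet_events entry in the task context) survive a round trip intact.
    decoded = BaseSerialization.deserialize(encoded)
    assert decoded["hi"].extra == {"this": "that"}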