Skip to content

Commit

Permalink
Ensure we can serialize DatasetEventAccessor(s) properly (apache#38993)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish authored Apr 15, 2024
1 parent 081637e commit 1f0f907
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
2 changes: 2 additions & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class DagAttributeTypes(str, Enum):
"""Enum of supported attribute types of DAG."""

DAG = "dag"
DATASET_EVENT_ACCESSORS = "dataset_event_accessors"
DATASET_EVENT_ACCESSOR = "dataset_event_accessor"
OP = "operator"
DATETIME = "datetime"
TIMEDELTA = "timedelta"
Expand Down
20 changes: 18 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
airflow_priority_weight_strategies_classes,
)
from airflow.utils.code_utils import get_python_source
from airflow.utils.context import Context
from airflow.utils.context import Context, DatasetEventAccessor, DatasetEventAccessors
from airflow.utils.docs import get_docs_url
from airflow.utils.module_loading import import_string, qualname
from airflow.utils.operator_resources import Resources
Expand Down Expand Up @@ -534,6 +534,16 @@ def serialize(
elif var.__class__.__name__ == "V1Pod" and _has_kubernetes() and isinstance(var, k8s.V1Pod):
json_pod = PodGenerator.serialize_pod(var)
return cls._encode(json_pod, type_=DAT.POD)
elif isinstance(var, DatasetEventAccessors):
return cls._encode(
cls.serialize(var._dict, strict=strict, use_pydantic_models=use_pydantic_models), # type: ignore[attr-defined]
type_=DAT.DATASET_EVENT_ACCESSORS,
)
elif isinstance(var, DatasetEventAccessor):
return cls._encode(
cls.serialize(var.extra, strict=strict, use_pydantic_models=use_pydantic_models),
type_=DAT.DATASET_EVENT_ACCESSOR,
)
elif isinstance(var, DAG):
return cls._encode(SerializedDAG.serialize_dag(var), type_=DAT.DAG)
elif isinstance(var, Resources):
Expand Down Expand Up @@ -663,8 +673,14 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
d[k] = cls.deserialize(v, use_pydantic_models=True)
d["task"] = d["task_instance"].task # todo: add `_encode` of Operator so we don't need this
return Context(**d)
if type_ == DAT.DICT:
elif type_ == DAT.DICT:
return {k: cls.deserialize(v, use_pydantic_models) for k, v in var.items()}
elif type_ == DAT.DATASET_EVENT_ACCESSORS:
d = DatasetEventAccessors() # type: ignore[assignment]
d._dict = cls.deserialize(var) # type: ignore[attr-defined]
return d
elif type_ == DAT.DATASET_EVENT_ACCESSOR:
return DatasetEventAccessor(extra=cls.deserialize(var))
elif type_ == DAT.DAG:
return SerializedDAG.deserialize_dag(var)
elif type_ == DAT.OP:
Expand Down
1 change: 1 addition & 0 deletions airflow/utils/context.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class ConnectionAccessor:
def get(self, key: str, default_conn: Any = None) -> Any: ...

class DatasetEventAccessor:
    """Stub for the dataset-event accessor exposed in task context.

    Mirrors the runtime class in ``airflow.utils.context``; ``extra`` is the
    keyword-only payload restored by ``BaseSerialization.deserialize``
    (DAT.DATASET_EVENT_ACCESSOR). NOTE(review): the runtime code assigns
    non-dict values to ``extra`` in tests — confirm whether the ``dict``
    annotation here is intended to be enforced.
    """

    def __init__(self, *, extra: dict[str, Any]) -> None: ...
    # JSON-serializable payload attached to the dataset event.
    extra: dict[str, Any]

class DatasetEventAccessors(Mapping[str, DatasetEventAccessor]):
Expand Down
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,7 @@ segmentGranularity
Sendgrid
sendgrid
sentimentMax
ser
serde
serialise
serializable
Expand Down
12 changes: 12 additions & 0 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.context import DatasetEventAccessors
from airflow.utils.operator_resources import Resources
from airflow.utils.pydantic import BaseModel
from airflow.utils.state import DagRunState, State
Expand Down Expand Up @@ -410,3 +411,14 @@ def test_serialized_mapped_operator_unmap(dag_maker):

serialized_unmapped_task = serialized_task2.unmap(None)
assert serialized_unmapped_task.dag is serialized_dag


def test_ser_of_dataset_event_accessor():
    """Round-trip DatasetEventAccessors through BaseSerialization.

    Verifies that both string-valued and dict-valued ``extra`` payloads
    survive serialize -> deserialize unchanged.
    """
    # todo: (Airflow 3.0) we should force reserialization on upgrade
    d = DatasetEventAccessors()
    d["hi"].extra = "blah1"  # todo: this should maybe be forbidden? i.e. can extra be any json or just dict?
    d["yo"].extra = {"this": "that", "the": "other"}
    ser = BaseSerialization.serialize(var=d)
    deser = BaseSerialization.deserialize(ser)
    assert deser["hi"].extra == "blah1"
    # Bug fix: assert on the *deserialized* object, not the original `d`
    # (comparing d to itself was trivially true and never exercised the
    # round-trip of a dict-valued extra).
    assert deser["yo"].extra == {"this": "that", "the": "other"}

0 comments on commit 1f0f907

Please sign in to comment.