From 947a6d336784b0ee0e72e48d540dbec3f7eca095 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Mon, 4 Mar 2024 14:10:18 +0800 Subject: [PATCH] Fix MappedOperator property types (#37870) --- airflow/decorators/base.py | 1 - airflow/models/baseoperator.py | 3 +- airflow/models/mappedoperator.py | 53 ++++++++++++--------- airflow/serialization/serialized_objects.py | 1 - 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 93c403e0bbe1c..51ebbce29c24c 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -459,7 +459,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input. partial_kwargs=partial_kwargs, task_id=task_id, - map_index_template=partial_kwargs.pop("map_index_template", None), params=partial_params, deps=MappedOperator.deps_for(self.operator_class), operator_extra_links=self.operator_class.operator_extra_links, diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 18d596bc4ac6b..c563b0e63f8cb 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -196,7 +196,8 @@ def partial(**kwargs): return self.class_method.__get__(cls, cls) -_PARTIAL_DEFAULTS = { +_PARTIAL_DEFAULTS: dict[str, Any] = { + "map_index_template": None, "owner": DEFAULT_OWNER, "trigger_rule": DEFAULT_TRIGGER_RULE, "depends_on_past": False, diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 1e18249a22c90..b2e85bbb7a809 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -58,6 +58,7 @@ if TYPE_CHECKING: import datetime + from typing import List import jinja2 # Slow import. import pendulum @@ -83,6 +84,8 @@ from airflow.utils.task_group import TaskGroup from airflow.utils.trigger_rule import TriggerRule + TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]] + ValidationSource = Union[Literal["expand"], Literal["partial"]] @@ -211,7 +214,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator: expand_input=expand_input, partial_kwargs=partial_kwargs, task_id=task_id, - map_index_template=partial_kwargs.pop("map_index_template", None), params=self.params, deps=MappedOperator.deps_for(self.operator_class), operator_extra_links=self.operator_class.operator_extra_links, @@ -281,7 +283,6 @@ class MappedOperator(AbstractOperator): end_date: pendulum.DateTime | None upstream_task_ids: set[str] = attr.ib(factory=set, init=False) downstream_task_ids: set[str] = attr.ib(factory=set, init=False) - map_index_template: str | None _disallow_kwargs_override: bool """Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``. @@ -392,6 +393,14 @@ def owner(self) -> str: # type: ignore[override] def email(self) -> None | str | Iterable[str]: return self.partial_kwargs.get("email") + @property + def map_index_template(self) -> None | str: + return self.partial_kwargs.get("map_index_template") + + @map_index_template.setter + def map_index_template(self, value: str | None) -> None: + self.partial_kwargs["map_index_template"] = value + @property def trigger_rule(self) -> TriggerRule: return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) @@ -453,11 +462,11 @@ def wait_for_downstream(self, value: bool) -> None: self.partial_kwargs["wait_for_downstream"] = value @property - def retries(self) -> int | None: + def retries(self) -> int: return self.partial_kwargs.get("retries", DEFAULT_RETRIES) @retries.setter - def retries(self, value: int | None) -> None: + def retries(self, value: int) -> None: self.partial_kwargs["retries"] = value @property @@ -465,7 +474,7 @@ def queue(self) -> str: return self.partial_kwargs.get("queue", DEFAULT_QUEUE) @queue.setter - def queue(self, value: str | None) -> None: + def queue(self, value: str) -> None: self.partial_kwargs["queue"] = value @property @@ -473,15 +482,15 @@ def pool(self) -> str: return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME) @pool.setter - def pool(self, value: str | None) -> None: + def pool(self, value: str) -> None: self.partial_kwargs["pool"] = value @property - def pool_slots(self) -> str | None: + def pool_slots(self) -> int: return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS) @pool_slots.setter - def pool_slots(self, value: str | None) -> None: + def pool_slots(self, value: int) -> None: self.partial_kwargs["pool_slots"] = value @property @@ -505,7 +514,7 @@ def retry_delay(self) -> datetime.timedelta: return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY) @retry_delay.setter - def retry_delay(self, value: datetime.timedelta | None) -> None: + def retry_delay(self, value: datetime.timedelta) -> None: self.partial_kwargs["retry_delay"] = value @property @@ -513,7 +522,7 @@ def retry_exponential_backoff(self) -> bool: return bool(self.partial_kwargs.get("retry_exponential_backoff")) @retry_exponential_backoff.setter - def retry_exponential_backoff(self, value: bool | None) -> None: + def retry_exponential_backoff(self, value: bool) -> None: self.partial_kwargs["retry_exponential_backoff"] = value @property @@ -521,7 +530,7 @@ def priority_weight(self) -> int: # type: ignore[override] return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT) @priority_weight.setter - def priority_weight(self, value: int | None) -> None: + def priority_weight(self, value: int) -> None: self.partial_kwargs["priority_weight"] = value @property @@ -529,7 +538,7 @@ def weight_rule(self) -> str: # type: ignore[override] return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) @weight_rule.setter - def weight_rule(self, value: str | None) -> None: + def weight_rule(self, value: str) -> None: self.partial_kwargs["weight_rule"] = value @property @@ -561,43 +570,43 @@ def resources(self) -> Resources | None: return self.partial_kwargs.get("resources") @property - def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]: + def on_execute_callback(self) -> TaskStateChangeCallbackAttrType: return self.partial_kwargs.get("on_execute_callback") @on_execute_callback.setter - def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None: + def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None: self.partial_kwargs["on_execute_callback"] = value @property - def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]: + def on_failure_callback(self) -> TaskStateChangeCallbackAttrType: return self.partial_kwargs.get("on_failure_callback") @on_failure_callback.setter - def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None: + def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None: self.partial_kwargs["on_failure_callback"] = value @property - def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]: + def on_retry_callback(self) -> TaskStateChangeCallbackAttrType: return self.partial_kwargs.get("on_retry_callback") @on_retry_callback.setter - def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None: + def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None: self.partial_kwargs["on_retry_callback"] = value @property - def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]: + def on_success_callback(self) -> TaskStateChangeCallbackAttrType: return self.partial_kwargs.get("on_success_callback") @on_success_callback.setter - def on_success_callback(self, value: TaskStateChangeCallback | None) -> None: + def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None: self.partial_kwargs["on_success_callback"] = value @property - def on_skipped_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]: + def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType: return self.partial_kwargs.get("on_skipped_callback") @on_skipped_callback.setter - def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> None: + def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None: self.partial_kwargs["on_skipped_callback"] = value @property diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 82c4a6bccc5a8..f2d4aed8900d0 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -1130,7 +1130,6 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator: task_group=None, start_date=None, end_date=None, - map_index_template=None, disallow_kwargs_override=encoded_op["_disallow_kwargs_override"], expand_input_attr=encoded_op["_expand_input_attr"], )