Skip to content

Commit

Permalink
Fix MappedOperator property types (apache#37870)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Mar 4, 2024
1 parent 1b4b73e commit 947a6d3
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 25 deletions.
1 change: 0 additions & 1 deletion airflow/decorators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
map_index_template=partial_kwargs.pop("map_index_template", None),
params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
Expand Down
3 changes: 2 additions & 1 deletion airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ def partial(**kwargs):
return self.class_method.__get__(cls, cls)


_PARTIAL_DEFAULTS = {
_PARTIAL_DEFAULTS: dict[str, Any] = {
"map_index_template": None,
"owner": DEFAULT_OWNER,
"trigger_rule": DEFAULT_TRIGGER_RULE,
"depends_on_past": False,
Expand Down
53 changes: 31 additions & 22 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

if TYPE_CHECKING:
import datetime
from typing import List

import jinja2 # Slow import.
import pendulum
Expand All @@ -83,6 +84,8 @@
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule

TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, List[TaskStateChangeCallback]]

ValidationSource = Union[Literal["expand"], Literal["partial"]]


Expand Down Expand Up @@ -211,7 +214,6 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
map_index_template=partial_kwargs.pop("map_index_template", None),
params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
Expand Down Expand Up @@ -281,7 +283,6 @@ class MappedOperator(AbstractOperator):
end_date: pendulum.DateTime | None
upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
map_index_template: str | None

_disallow_kwargs_override: bool
"""Whether execution fails if ``expand_input`` has duplicates to ``partial_kwargs``.
Expand Down Expand Up @@ -392,6 +393,14 @@ def owner(self) -> str: # type: ignore[override]
def email(self) -> None | str | Iterable[str]:
return self.partial_kwargs.get("email")

@property
def map_index_template(self) -> None | str:
return self.partial_kwargs.get("map_index_template")

@map_index_template.setter
def map_index_template(self, value: str | None) -> None:
self.partial_kwargs["map_index_template"] = value

@property
def trigger_rule(self) -> TriggerRule:
return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE)
Expand Down Expand Up @@ -453,35 +462,35 @@ def wait_for_downstream(self, value: bool) -> None:
self.partial_kwargs["wait_for_downstream"] = value

@property
def retries(self) -> int | None:
def retries(self) -> int:
return self.partial_kwargs.get("retries", DEFAULT_RETRIES)

@retries.setter
def retries(self, value: int | None) -> None:
def retries(self, value: int) -> None:
self.partial_kwargs["retries"] = value

@property
def queue(self) -> str:
return self.partial_kwargs.get("queue", DEFAULT_QUEUE)

@queue.setter
def queue(self, value: str | None) -> None:
def queue(self, value: str) -> None:
self.partial_kwargs["queue"] = value

@property
def pool(self) -> str:
return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)

@pool.setter
def pool(self, value: str | None) -> None:
def pool(self, value: str) -> None:
self.partial_kwargs["pool"] = value

@property
def pool_slots(self) -> str | None:
def pool_slots(self) -> int:
return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)

@pool_slots.setter
def pool_slots(self, value: str | None) -> None:
def pool_slots(self, value: int) -> None:
self.partial_kwargs["pool_slots"] = value

@property
Expand All @@ -505,31 +514,31 @@ def retry_delay(self) -> datetime.timedelta:
return self.partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY)

@retry_delay.setter
def retry_delay(self, value: datetime.timedelta | None) -> None:
def retry_delay(self, value: datetime.timedelta) -> None:
self.partial_kwargs["retry_delay"] = value

@property
def retry_exponential_backoff(self) -> bool:
return bool(self.partial_kwargs.get("retry_exponential_backoff"))

@retry_exponential_backoff.setter
def retry_exponential_backoff(self, value: bool | None) -> None:
def retry_exponential_backoff(self, value: bool) -> None:
self.partial_kwargs["retry_exponential_backoff"] = value

@property
def priority_weight(self) -> int: # type: ignore[override]
return self.partial_kwargs.get("priority_weight", DEFAULT_PRIORITY_WEIGHT)

@priority_weight.setter
def priority_weight(self, value: int | None) -> None:
def priority_weight(self, value: int) -> None:
self.partial_kwargs["priority_weight"] = value

@property
def weight_rule(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)

@weight_rule.setter
def weight_rule(self, value: str | None) -> None:
def weight_rule(self, value: str) -> None:
self.partial_kwargs["weight_rule"] = value

@property
Expand Down Expand Up @@ -561,43 +570,43 @@ def resources(self) -> Resources | None:
return self.partial_kwargs.get("resources")

@property
def on_execute_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_execute_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_execute_callback")

@on_execute_callback.setter
def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_execute_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_execute_callback"] = value

@property
def on_failure_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_failure_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_failure_callback")

@on_failure_callback.setter
def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_failure_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_failure_callback"] = value

@property
def on_retry_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_retry_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_retry_callback")

@on_retry_callback.setter
def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_retry_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_retry_callback"] = value

@property
def on_success_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_success_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_success_callback")

@on_success_callback.setter
def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_success_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_success_callback"] = value

@property
def on_skipped_callback(self) -> None | TaskStateChangeCallback | list[TaskStateChangeCallback]:
def on_skipped_callback(self) -> TaskStateChangeCallbackAttrType:
return self.partial_kwargs.get("on_skipped_callback")

@on_skipped_callback.setter
def on_skipped_callback(self, value: TaskStateChangeCallback | None) -> None:
def on_skipped_callback(self, value: TaskStateChangeCallbackAttrType) -> None:
self.partial_kwargs["on_skipped_callback"] = value

@property
Expand Down
1 change: 0 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,6 @@ def deserialize_operator(cls, encoded_op: dict[str, Any]) -> Operator:
task_group=None,
start_date=None,
end_date=None,
map_index_template=None,
disallow_kwargs_override=encoded_op["_disallow_kwargs_override"],
expand_input_attr=encoded_op["_expand_input_attr"],
)
Expand Down

0 comments on commit 947a6d3

Please sign in to comment.