From 39d207be850025c414ed429ef2692a07515889b2 Mon Sep 17 00:00:00 2001 From: Karen Braganza Date: Wed, 2 Oct 2024 02:20:43 -0400 Subject: [PATCH] Check pool_slots on partial task import instead of execution (#39724) Co-authored-by: Ryan Hatter <25823361+RNHTTR@users.noreply.github.com> --- airflow/decorators/base.py | 6 ++++++ airflow/models/baseoperator.py | 5 +++++ tests/models/test_mappedoperator.py | 9 +++++++++ 3 files changed, 20 insertions(+) diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index e650c1920a870..bb9602d50c1cd 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -468,6 +468,12 @@ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) if partial_kwargs.get("pool") is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if "pool_slots" in partial_kwargs: + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) partial_kwargs["retry_delay"] = coerce_timedelta( partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 20656586ba01e..9e0c8e1e69b61 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -358,6 +358,11 @@ def partial( partial_kwargs["end_date"] = timezone.convert_to_utc(partial_kwargs["end_date"]) if partial_kwargs["pool"] is None: partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME + if partial_kwargs["pool_slots"] < 1: + dag_str = "" + if dag: + dag_str = f" in dag {dag.dag_id}" + raise ValueError(f"pool slots for {task_id}{dag_str} cannot be less than 1") partial_kwargs["retries"] = parse_retries(partial_kwargs["retries"]) partial_kwargs["retry_delay"] = coerce_timedelta(partial_kwargs["retry_delay"], key="retry_delay") if partial_kwargs["max_retry_delay"] is not None: diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 2b0cd50165c45..0571e07e671f8 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -220,6 +220,15 @@ def test_partial_on_class_invalid_ctor_args() -> None: MockOperator.partial(task_id="a", foo="bar", bar=2) +def test_partial_on_invalid_pool_slots_raises() -> None: + """Test that when we pass an invalid value to pool_slots in partial(), + + i.e. if the value is not an integer, an error is raised at import time.""" + + with pytest.raises(TypeError, match="'<' not supported between instances of 'str' and 'int'"): + MockOperator.partial(task_id="pool_slots_test", pool="test", pool_slots="a").expand(arg1=[1, 2, 3]) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @pytest.mark.parametrize( ["num_existing_tis", "expected"],