diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 512eb189cc9a5..d59548003697a 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -364,6 +364,8 @@ def wrapper(self, *args, **kwargs): sentinel = kwargs.pop(sentinel_key, None) if sentinel: + if not getattr(cls._sentinel, "callers", None): + cls._sentinel.callers = {} cls._sentinel.callers[sentinel_key] = sentinel else: sentinel = cls._sentinel.callers.pop(f"{func.__qualname__.split('.')[0]}__sentinel", None) diff --git a/tests/models/test_baseoperatormeta.py b/tests/models/test_baseoperatormeta.py index 1e96c73aae9ce..b16ea1e68397a 100644 --- a/tests/models/test_baseoperatormeta.py +++ b/tests/models/test_baseoperatormeta.py @@ -18,6 +18,7 @@ from __future__ import annotations import datetime +import threading from typing import TYPE_CHECKING, Any from unittest.mock import patch @@ -204,3 +205,20 @@ def say_hello(**context): mock_log.warning.assert_called_once_with( "HelloWorldOperator.execute cannot be called outside TaskInstance!" ) + + def test_thread_local_executor_safeguard(self): + class TestExecutorSafeguardThread(threading.Thread): + def __init__(self): + threading.Thread.__init__(self) + self.executor_safeguard = ExecutorSafeguard() + + def run(self): + class Wrapper: + def wrapper_test_func(self, *args, **kwargs): + print("test") + + wrap_func = self.executor_safeguard.decorator(Wrapper.wrapper_test_func) + wrap_func(Wrapper(), Wrapper__sentinel="abc") + + # Test thread local caller value is set properly + TestExecutorSafeguardThread().start()