From 47cdb84b351bd75c8232406d5673405dcb7efd64 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Wed, 20 Nov 2024 17:04:37 +0000 Subject: [PATCH] Rewite LocalExecutor to be simpler, and to shutdown cleanly on Python 3.10+ (#23944) Something changed between Python 3.7 and 3.10 meaning that a limited parallelism LocalExecutor scheduler now doesn't shutdown cleanly on receiving a signal. On closer inspection of the limited vs unlimited path it apepars to me that the code was "over-generalized" and the entire concept of `self.impl` has been removed hopefully making this code much more direct and easier to understand. The key things are now: - When a task needs to be run, we send the message on a mp.SimpleQueue object, and increment an internal counter. (We use our own counter, not qsize method as that is not portable) - Inside _check_workers we see if we think there are any outstanding messages, and create a worker if there are. The reason we do this is the on macOS (where the default mp start method is "spawn") a process will be started via `exeucte_async`, but it will take a second or two to pull the message of the queue, by which time the scheduler will have called `executor.sync()` again, meaning we'd over create workers (but never above the limit). Avoiding that case is why we keep the internal `_unread_messages` counter -- `self.activity_queue.empty()` would return False when the worker is booting up. - We remove the entire use of `multiprocessing.Manager` -- it doesn't seem to do anything other than create queue objects but for our use it just adds complexity to understanding - Almost as a side-effect we now only create worker subprocesses on demand, instead of pre-launching them. We do not currently shut down idle processes, though adding it would be quite straight forward if we wanted to in the future This branch name was "rewrite-local-exexc-concurrentfutures" (sic) as when originally opened in 2022 for 3.10 that was the plan. However since then 3.12 has come out and it now starts issuing warnings when Fork and threads are used, and concurrent.futures uses a thread internally, so a different approach was used. --- airflow/executors/local_executor.py | 504 ++++++++++--------------- tests/executors/test_local_executor.py | 113 +++--- 2 files changed, 252 insertions(+), 365 deletions(-) diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py index cb19b57a81501..3b8b52176db5b 100644 --- a/airflow/executors/local_executor.py +++ b/airflow/executors/local_executor.py @@ -25,192 +25,146 @@ from __future__ import annotations -import contextlib +import ctypes import logging +import multiprocessing +import multiprocessing.sharedctypes import os import subprocess -from abc import abstractmethod -from multiprocessing import Manager, Process -from queue import Empty +from multiprocessing import Queue, SimpleQueue from typing import TYPE_CHECKING, Any, Optional, Tuple -from setproctitle import getproctitle, setproctitle +from setproctitle import setproctitle from airflow import settings -from airflow.exceptions import AirflowException from airflow.executors.base_executor import PARALLELISM, BaseExecutor -from airflow.traces.tracer import Trace, add_span +from airflow.traces.tracer import add_span from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager -from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: - from multiprocessing.managers import SyncManager - from queue import Queue - from airflow.executors.base_executor import CommandType - from airflow.models.taskinstance import TaskInstanceStateType from airflow.models.taskinstancekey import TaskInstanceKey # This is a work to be executed by a worker. # It can Key and Command - but it can also be None, None which is actually a # "Poison Pill" - worker seeing Poison Pill should take the pill and ... die instantly. - ExecutorWorkType = Tuple[Optional[TaskInstanceKey], Optional[CommandType]] + ExecutorWorkType = Optional[Tuple[TaskInstanceKey, CommandType]] + TaskInstanceStateType = Tuple[TaskInstanceKey, TaskInstanceState, Optional[Exception]] -class LocalWorkerBase(Process, LoggingMixin): - """ - LocalWorkerBase implementation to run airflow commands. +def _run_worker( + logger_name: str, + input: SimpleQueue[ExecutorWorkType], + output: Queue[TaskInstanceStateType], + unread_messages: multiprocessing.sharedctypes.Synchronized[int], +): + import signal - Executes the given command and puts the result into a result queue when done, terminating execution. + # Ignore ctrl-c in this process -- we don't want to kill _this_ one. we let tasks run to completion + signal.signal(signal.SIGINT, signal.SIG_IGN) - :param result_queue: the queue to store result state - """ + log = logging.getLogger(logger_name) - def __init__(self, result_queue: Queue[TaskInstanceStateType]): - super().__init__(target=self.do_work) - self.daemon: bool = True - self.result_queue: Queue[TaskInstanceStateType] = result_queue + # We know we've just started a new process, so lets disconnect from the metadata db now + settings.engine.pool.dispose() + settings.engine.dispose() - def run(self): - # We know we've just started a new process, so lets disconnect from the metadata db now - settings.engine.pool.dispose() - settings.engine.dispose() - setproctitle("airflow worker -- LocalExecutor") - return super().run() + setproctitle("airflow worker -- LocalExecutor: ") - @add_span - def execute_work(self, key: TaskInstanceKey, command: CommandType) -> None: - """ - Execute command received and stores result state in queue. - - :param key: the key to identify the task instance - :param command: the command to execute - """ - if key is None: - return - - self.log.info("%s running %s", self.__class__.__name__, command) - setproctitle(f"airflow worker -- LocalExecutor: {command}") - dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) - with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): - if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - state = self._execute_work_in_subprocess(command) - else: - state = self._execute_work_in_fork(command) - - self.result_queue.put((key, state)) - # Remove the command since the worker is done executing the task - setproctitle("airflow worker -- LocalExecutor") - - @add_span - def _execute_work_in_subprocess(self, command: CommandType) -> TaskInstanceState: + while True: try: - subprocess.check_call(command, close_fds=True) - return TaskInstanceState.SUCCESS - except subprocess.CalledProcessError as e: - self.log.error("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - - @add_span - def _execute_work_in_fork(self, command: CommandType) -> TaskInstanceState: - pid = os.fork() - if pid: - # In parent, wait for the child - pid, ret = os.waitpid(pid, 0) - return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED + item = input.get() + except EOFError: + log.info( + "Failed to read tasks from the task queue because the other " + "end has closed the connection. Terminating worker %s.", + multiprocessing.current_process().name, + ) + break + + if item is None: + # Received poison pill, no more tasks to run + return - from airflow.sentry import Sentry + # Decrement this as soon as we pick up a message off the queue + with unread_messages: + unread_messages.value -= 1 - ret = 1 + (key, command) = item try: - import signal - - from airflow.cli.cli_parser import get_parser - - signal.signal(signal.SIGINT, signal.SIG_DFL) - signal.signal(signal.SIGTERM, signal.SIG_DFL) - signal.signal(signal.SIGUSR2, signal.SIG_DFL) - - parser = get_parser() - # [1:] - remove "airflow" from the start of the command - args = parser.parse_args(command[1:]) - args.shut_down_logging = False + state = _execute_work(log, key, command) - setproctitle(f"airflow task supervisor: {command}") - - args.func(args) - ret = 0 - return TaskInstanceState.SUCCESS + output.put((key, state, None)) except Exception as e: - self.log.exception("Failed to execute task %s.", e) - return TaskInstanceState.FAILED - finally: - Sentry.flush() - logging.shutdown() - os._exit(ret) - - @abstractmethod - def do_work(self): - """Execute tasks; called in the subprocess.""" - raise NotImplementedError() + output.put((key, TaskInstanceState.FAILED, e)) -class LocalWorker(LocalWorkerBase): +def _execute_work(log: logging.Logger, key: TaskInstanceKey, command: CommandType) -> TaskInstanceState: """ - Local worker that executes the task. + Execute command received and stores result state in queue. - :param result_queue: queue where results of the tasks are put. - :param key: key identifying task instance - :param command: Command to execute + :param key: the key to identify the task instance + :param command: the command to execute """ + setproctitle(f"airflow worker -- LocalExecutor: {command}") + dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command) + try: + with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id): + if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: + return _execute_work_in_subprocess(log, command) + else: + return _execute_work_in_fork(log, command) + finally: + # Remove the command since the worker is done executing the task + setproctitle("airflow worker -- LocalExecutor: ") - def __init__( - self, result_queue: Queue[TaskInstanceStateType], key: TaskInstanceKey, command: CommandType - ): - super().__init__(result_queue) - self.key: TaskInstanceKey = key - self.command: CommandType = command - @add_span - def do_work(self) -> None: - self.execute_work(key=self.key, command=self.command) +def _execute_work_in_subprocess(log: logging.Logger, command: CommandType) -> TaskInstanceState: + try: + subprocess.check_call(command, close_fds=True) + return TaskInstanceState.SUCCESS + except subprocess.CalledProcessError as e: + log.error("Failed to execute task %s.", e) + return TaskInstanceState.FAILED -class QueuedLocalWorker(LocalWorkerBase): - """ - LocalWorker implementation that is waiting for tasks from a queue. +def _execute_work_in_fork(log: logging.Logger, command: CommandType) -> TaskInstanceState: + pid = os.fork() + if pid: + # In parent, wait for the child + pid, ret = os.waitpid(pid, 0) + return TaskInstanceState.SUCCESS if ret == 0 else TaskInstanceState.FAILED - Will continue executing commands as they become available in the queue. - It will terminate execution once the poison token is found. + from airflow.sentry import Sentry - :param task_queue: queue from which worker reads tasks - :param result_queue: queue where worker puts results after finishing tasks - """ + ret = 1 + try: + import signal - def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[TaskInstanceStateType]): - super().__init__(result_queue=result_queue) - self.task_queue = task_queue + from airflow.cli.cli_parser import get_parser - @add_span - def do_work(self) -> None: - while True: - try: - key, command = self.task_queue.get() - except EOFError: - self.log.info( - "Failed to read tasks from the task queue because the other " - "end has closed the connection. Terminating worker %s.", - self.name, - ) - break - try: - if key is None or command is None: - # Received poison pill, no more tasks to run - break - self.execute_work(key=key, command=command) - finally: - self.task_queue.task_done() + signal.signal(signal.SIGINT, signal.SIG_IGN) + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGUSR2, signal.SIG_DFL) + + parser = get_parser() + # [1:] - remove "airflow" from the start of the command + args = parser.parse_args(command[1:]) + args.shut_down_logging = False + + setproctitle(f"airflow task supervisor: {command}") + + args.func(args) + ret = 0 + return TaskInstanceState.SUCCESS + except Exception as e: + log.exception("Failed to execute task %s.", e) + return TaskInstanceState.FAILED + finally: + Sentry.flush() + logging.shutdown() + os._exit(ret) class LocalExecutor(BaseExecutor): @@ -226,173 +180,28 @@ class LocalExecutor(BaseExecutor): serve_logs: bool = True + activity_queue: SimpleQueue[ExecutorWorkType] + result_queue: SimpleQueue[TaskInstanceStateType] + workers: dict[int, multiprocessing.Process] + _unread_messages: multiprocessing.sharedctypes.Synchronized[int] + def __init__(self, parallelism: int = PARALLELISM): super().__init__(parallelism=parallelism) if self.parallelism < 0: - raise AirflowException("parallelism must be bigger than or equal to 0") - self.manager: SyncManager | None = None - self.result_queue: Queue[TaskInstanceStateType] | None = None - self.workers: list[QueuedLocalWorker] = [] - self.workers_used: int = 0 - self.workers_active: int = 0 - self.impl: None | (LocalExecutor.UnlimitedParallelism | LocalExecutor.LimitedParallelism) = None - - class UnlimitedParallelism: - """ - Implement LocalExecutor with unlimited parallelism, starting one process per command executed. - - :param executor: the executor instance to implement. - """ - - def __init__(self, executor: LocalExecutor): - self.executor: LocalExecutor = executor - - def start(self) -> None: - """Start the executor.""" - self.executor.workers_used = 0 - self.executor.workers_active = 0 - - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """ - Execute task asynchronously. - - :param key: the key to identify the task instance - :param command: the command to execute - :param queue: Name of the queue - :param executor_config: configuration for the executor - """ - if TYPE_CHECKING: - assert self.executor.result_queue - - span = Trace.get_current_span() - if span.is_recording(): - span.set_attributes( - { - "dag_id": key.dag_id, - "run_id": key.run_id, - "task_id": key.task_id, - "try_number": key.try_number, - "commands_to_run": str(command), - } - ) - - local_worker = LocalWorker(self.executor.result_queue, key=key, command=command) - self.executor.workers_used += 1 - self.executor.workers_active += 1 - local_worker.start() - - def sync(self) -> None: - """Sync will get called periodically by the heartbeat method.""" - if not self.executor.result_queue: - raise AirflowException("Executor should be started first") - while not self.executor.result_queue.empty(): - results = self.executor.result_queue.get() - self.executor.change_state(*results) - self.executor.workers_active -= 1 - - def end(self) -> None: - """Wait synchronously for the previously submitted job to complete.""" - while self.executor.workers_active > 0: - self.executor.sync() - - class LimitedParallelism: - """ - Implements LocalExecutor with limited parallelism. - - Uses a task queue to coordinate work distribution. - - :param executor: the executor instance to implement. - """ - - def __init__(self, executor: LocalExecutor): - self.executor: LocalExecutor = executor - self.queue: Queue[ExecutorWorkType] | None = None - - def start(self) -> None: - """Start limited parallelism implementation.""" - if TYPE_CHECKING: - assert self.executor.manager - assert self.executor.result_queue - - self.queue = self.executor.manager.Queue() - self.executor.workers = [ - QueuedLocalWorker(self.queue, self.executor.result_queue) - for _ in range(self.executor.parallelism) - ] - - self.executor.workers_used = len(self.executor.workers) - - for worker in self.executor.workers: - worker.start() - - @add_span - def execute_async( - self, - key: TaskInstanceKey, - command: CommandType, - queue: str | None = None, - executor_config: Any | None = None, - ) -> None: - """ - Execute task asynchronously. - - :param key: the key to identify the task instance - :param command: the command to execute - :param queue: name of the queue - :param executor_config: configuration for the executor - """ - if TYPE_CHECKING: - assert self.queue - - self.queue.put((key, command)) - - def sync(self): - """Sync will get called periodically by the heartbeat method.""" - with contextlib.suppress(Empty): - while True: - results = self.executor.result_queue.get_nowait() - try: - self.executor.change_state(*results) - finally: - self.executor.result_queue.task_done() - - def end(self): - """ - End the executor. - - Sends the poison pill to all workers. - """ - for _ in self.executor.workers: - self.queue.put((None, None)) - - # Wait for commands to finish - self.queue.join() - self.executor.sync() + raise ValueError("parallelism must be greater than or equal to 0") def start(self) -> None: """Start the executor.""" - old_proctitle = getproctitle() - setproctitle("airflow executor -- LocalExecutor") - self.manager = Manager() - setproctitle(old_proctitle) - self.result_queue = self.manager.Queue() - self.workers = [] - self.workers_used = 0 - self.workers_active = 0 - self.impl = ( - LocalExecutor.UnlimitedParallelism(self) - if self.parallelism == 0 - else LocalExecutor.LimitedParallelism(self) - ) + # We delay opening these queues until the start method mostly for unit tests. ExecutorLoader caches + # instances, so each test reusues the same instance! (i.e. test 1 runs, closes the queues, then test 2 + # comes back and gets the same LocalExecutor instance, so we have to open new here.) + self.activity_queue = SimpleQueue() + self.result_queue = SimpleQueue() + self.workers = {} - self.impl.start() + # Mypy sees this value as `SynchronizedBase[c_uint]`, but that isn't the right runtime type behaviour + # (it looks like an int to python) + self._unread_messages = multiprocessing.Value(ctypes.c_uint) # type: ignore[assignment] @add_span def execute_async( @@ -403,32 +212,97 @@ def execute_async( executor_config: Any | None = None, ) -> None: """Execute asynchronously.""" - if TYPE_CHECKING: - assert self.impl - self.validate_airflow_tasks_run_command(command) + self.activity_queue.put((key, command)) + with self._unread_messages: + self._unread_messages.value += 1 + self._check_workers(can_start=True) + + def _check_workers(self, can_start: bool = True): + # Reap any dead workers + to_remove = set() + for pid, proc in self.workers.items(): + if not proc.is_alive(): + to_remove.add(pid) + proc.close() + + if to_remove: + self.workers = {pid: proc for pid, proc in self.workers.items() if pid not in to_remove} + + with self._unread_messages: + num_outstanding = self._unread_messages.value + + if num_outstanding <= 0 or self.activity_queue.empty(): + # Nothing to do. Future enhancement if someone wants: shut down workers that have been idle for N + # seconds + return - self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config) + # If we're using spawn in multiprocessing (default on macOS now) to start tasks, this can get called a + # via `sync()` a few times before the spawned process actually starts picking up messages. Try not to + # create too much + need_more_workers = len(self.workers) < num_outstanding + if need_more_workers and (self.parallelism == 0 or len(self.workers) < self.parallelism): + # This only creates one worker, which is fine as we call this directly after putting a message on + # activity_queue in execute_async + self._spawn_worker() + + def _spawn_worker(self): + p = multiprocessing.Process( + target=_run_worker, + kwargs={ + "logger_name": self.log.name, + "input": self.activity_queue, + "output": self.result_queue, + "unread_messages": self._unread_messages, + }, + ) + p.start() + if TYPE_CHECKING: + assert p.pid # Since we've called start + self.workers[p.pid] = p def sync(self) -> None: """Sync will get called periodically by the heartbeat method.""" - if TYPE_CHECKING: - assert self.impl + self._read_results() + self._check_workers() - self.impl.sync() + def _read_results(self): + while not self.result_queue.empty(): + key, state, exc = self.result_queue.get() + + if exc: + # TODO: This needs a better stacktrace, it appears from here + if hasattr(exc, "add_note"): + exc.add_note("(This stacktrace is incorrect -- the exception came from a subprocess)") + raise exc + + self.change_state(key, state) def end(self) -> None: """End the executor.""" - if TYPE_CHECKING: - assert self.impl - assert self.manager - self.log.info( "Shutting down LocalExecutor" "; waiting for running tasks to finish. Signal again if you don't want to wait." ) - self.impl.end() - self.manager.shutdown() + + # We can't tell which proc will pick which close message up, so we send all the messages, and then + # wait on all the procs + + for proc in self.workers.values(): + # Send the shutdown message once for each alive worker + if proc.is_alive(): + self.activity_queue.put(None) + + for proc in self.workers.values(): + if proc.is_alive(): + proc.join() + proc.close() + + # Process any extra results before closing + self._read_results() + + self.activity_queue.close() + self.result_queue.close() def terminate(self): """Terminate the executor is not doing anything.""" diff --git a/tests/executors/test_local_executor.py b/tests/executors/test_local_executor.py index f6ba8dca464d3..2545ceb7e705d 100644 --- a/tests/executors/test_local_executor.py +++ b/tests/executors/test_local_executor.py @@ -18,10 +18,13 @@ from __future__ import annotations import datetime +import multiprocessing +import os import subprocess from unittest import mock import pytest +from kgb import spy_on from airflow import settings from airflow.exceptions import AirflowException @@ -30,6 +33,12 @@ pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] +# Runtime is fine, we just can't run the tests on macOS +skip_spawn_mp_start = pytest.mark.skipif( + multiprocessing.get_context().get_start_method() == "spawn", + reason="mock patching in test don't work with 'spawn' mode (default on macOS)", +) + class TestLocalExecutor: TEST_SUCCESS_COMMANDS = 5 @@ -44,85 +53,72 @@ def test_serve_logs_default_value(self): assert LocalExecutor.serve_logs @mock.patch("airflow.executors.local_executor.subprocess.check_call") - def execution_parallelism_subprocess(self, mock_check_call, parallelism=0): - success_command = ["airflow", "tasks", "run", "true", "some_parameter", "2020-10-07"] - fail_command = ["airflow", "tasks", "run", "false", "task_id", "2020-10-07"] + @mock.patch("airflow.cli.commands.task_command.task_run") + def _test_execute(self, mock_run, mock_check_call, parallelism=1): + success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"] + fail_command = ["airflow", "tasks", "run", "failure", "task_id", "2020-10-07"] + # We just mock both styles here, only one will be hit though def fake_execute_command(command, close_fds=True): if command != success_command: raise subprocess.CalledProcessError(returncode=1, cmd=command) else: return 0 - mock_check_call.side_effect = fake_execute_command - - self._test_execute(parallelism, success_command, fail_command) - - @mock.patch("airflow.cli.commands.task_command.task_run") - def execution_parallelism_fork(self, mock_run, parallelism=0): - success_command = ["airflow", "tasks", "run", "success", "some_parameter", "2020-10-07"] - fail_command = ["airflow", "tasks", "run", "failure", "some_parameter", "2020-10-07"] - def fake_task_run(args): if args.dag_id != "success": raise AirflowException("Simulate failed task") + mock_check_call.side_effect = fake_execute_command mock_run.side_effect = fake_task_run - self._test_execute(parallelism, success_command, fail_command) - - def _test_execute(self, parallelism, success_command, fail_command): executor = LocalExecutor(parallelism=parallelism) executor.start() success_key = "success {}" assert executor.result_queue.empty() - logical_date = datetime.datetime.now() - for i in range(self.TEST_SUCCESS_COMMANDS): - key_id, command = success_key.format(i), success_command - key = key_id, "fake_ti", logical_date, 0 - executor.running.add(key) - executor.execute_async(key=key, command=command) + with spy_on(executor._spawn_worker) as spawn_worker: + run_id = "manual_" + datetime.datetime.now().isoformat() + for i in range(self.TEST_SUCCESS_COMMANDS): + key_id, command = success_key.format(i), success_command + key = key_id, "fake_ti", run_id, 0 + executor.running.add(key) + executor.execute_async(key=key, command=command) - fail_key = "fail", "fake_ti", logical_date, 0 - executor.running.add(fail_key) - executor.execute_async(key=fail_key, command=fail_command) + fail_key = "fail", "fake_ti", run_id, 0 + executor.running.add(fail_key) + executor.execute_async(key=fail_key, command=fail_command) + + executor.end() + + expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism + # Depending on how quickly the tasks run, we might not need to create all the workers we could + assert 1 <= len(spawn_worker.calls) <= expected - executor.end() # By that time Queues are already shutdown so we cannot check if they are empty assert len(executor.running) == 0 + assert executor._unread_messages.value == 0 for i in range(self.TEST_SUCCESS_COMMANDS): key_id = success_key.format(i) - key = key_id, "fake_ti", logical_date, 0 + key = key_id, "fake_ti", run_id, 0 assert executor.event_buffer[key][0] == State.SUCCESS assert executor.event_buffer[fail_key][0] == State.FAILED - expected = self.TEST_SUCCESS_COMMANDS + 1 if parallelism == 0 else parallelism - assert executor.workers_used == expected - - def test_execution_subprocess_unlimited_parallelism(self): - with mock.patch.object( - settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock - ) as option: - option.return_value = True - self.execution_parallelism_subprocess(parallelism=0) - - def test_execution_subprocess_limited_parallelism(self): - with mock.patch.object( - settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", new_callable=mock.PropertyMock - ) as option: - option.return_value = True - self.execution_parallelism_subprocess(parallelism=2) - - @mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False) - def test_execution_unlimited_parallelism_fork(self): - self.execution_parallelism_fork(parallelism=0) - - @mock.patch.object(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", False) - def test_execution_limited_parallelism_fork(self): - self.execution_parallelism_fork(parallelism=2) + @skip_spawn_mp_start + @pytest.mark.parametrize( + ("parallelism", "fork_or_subproc"), + [ + pytest.param(0, True, id="unlimited_subprocess"), + pytest.param(2, True, id="limited_subprocess"), + pytest.param(0, False, id="unlimited_fork"), + pytest.param(2, False, id="limited_fork"), + ], + ) + def test_execution(self, parallelism: int, fork_or_subproc: bool, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(settings, "EXECUTE_TASKS_NEW_PYTHON_INTERPRETER", fork_or_subproc) + self._test_execute(parallelism=parallelism) @mock.patch("airflow.executors.local_executor.LocalExecutor.sync") @mock.patch("airflow.executors.base_executor.BaseExecutor.trigger_tasks") @@ -142,3 +138,20 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock ), ] mock_stats_gauge.assert_has_calls(calls) + + @pytest.mark.execution_timeout(5) + def test_clean_stop_on_signal(self): + import signal + + executor = LocalExecutor(parallelism=2) + executor.start() + + # We want to ensure we start a worker process, as we now only create them on demand + executor._spawn_worker() + + try: + os.kill(os.getpid(), signal.SIGINT) + except KeyboardInterrupt: + pass + finally: + executor.end()