diff --git a/airflow/jobs/local_task_job_runner.py b/airflow/jobs/local_task_job_runner.py index a6a1f0ac8fa23..77e7739e62e1c 100644 --- a/airflow/jobs/local_task_job_runner.py +++ b/airflow/jobs/local_task_job_runner.py @@ -169,6 +169,10 @@ def sigusr2_debug_handler(signum, frame): return_code = None try: self.task_runner.start() + self.task_started_at = timezone.utcnow() + self.task_execution_timeout = ( + self.task_instance.task.execution_timeout if self.task_instance.task is not None else None + ) local_task_job_heartbeat_sec = conf.getint("scheduler", "local_task_job_heartbeat_sec") if local_task_job_heartbeat_sec < 1: heartbeat_time_limit = conf.getint("scheduler", "scheduler_zombie_task_threshold") @@ -234,6 +238,31 @@ def sigusr2_debug_handler(signum, frame): f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit " f"({heartbeat_time_limit}s)." ) + + if self.task_execution_timeout is not None: + # If the time elapsed is longer than the execution timeout, + # then we need to terminate the process + elapsed_time = timezone.utcnow() - self.task_started_at + timed_out = elapsed_time > self.task_execution_timeout + if timed_out: + Stats.incr("local_task_job_execution_timeout", 1, 1) + self.log.error( + "The task (#%s) timed out after %s seconds! 
Terminating...", + self.task_instance.task_id, + self.task_execution_timeout.total_seconds(), + ) + if span.is_recording(): + span.add_event( + name="error", + attributes={ + "message": "Task timeout", + "task_execution_timeout(s)": self.task_execution_timeout.total_seconds(), + "elapsed_time(s)": elapsed_time.total_seconds(), + }, + ) + self.task_runner.terminate() + break + return return_code finally: # Print a marker for log grouping of details before task execution diff --git a/airflow/utils/process_utils.py b/airflow/utils/process_utils.py index e8b3e8840da86..527e71d089e10 100644 --- a/airflow/utils/process_utils.py +++ b/airflow/utils/process_utils.py @@ -46,23 +46,23 @@ log = logging.getLogger(__name__) -# When killing processes, time to wait after issuing a SIGTERM before issuing a -# SIGKILL. -DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint("core", "KILLED_TASK_CLEANUP_TIME") +# When killing processes, time to wait after issuing a signal (e.g: SIGHUP, SIGTERM) before issuing +# another signal (e.g: SIGTERM, SIGKILL). +# Possible WAIT scenarios between signals: SIGHUP -> WAIT -> SIGTERM -> WAIT -> SIGKILL +DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS = conf.getint("core", "KILLED_TASK_CLEANUP_TIME") def reap_process_group( process_group_id: int, logger, - sig: signal.Signals = signal.SIGTERM, - timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM, + timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS, ) -> dict[int, int]: """ - Send sig (SIGTERM) to the process group of pid. + Send SIGHUP to the process group of pid. Tries really hard to terminate all processes in the group (including grandchildren). Will send - sig (SIGTERM) to the process group of pid. If any process is alive after timeout - a SIGKILL will be send. + SIGHUP to the process group of pid. If any process is alive after timeout + a SIGTERM will be sent. If any process is still alive then a SIGKILL will be sent. :param process_group_id: process group id to kill. 
The process that wants to create the group should run @@ -71,11 +71,14 @@ def reap_process_group( "root" of the group has pid = gid and all other processes in the group have different pids but the same gid (equal the pid of the root process) :param logger: log handler - :param sig: signal type :param timeout: how much time a process has to terminate """ returncodes = {} + def on_hangup(p): + logger.info("Process %s (%s) hung up with exit code %s", p, p.pid, p.returncode) + returncodes[p.pid] = p.returncode + def on_terminate(p): logger.info("Process %s (%s) terminated with exit code %s", p, p.pid, p.returncode) returncodes[p.pid] = p.returncode @@ -129,21 +132,39 @@ def signal_procs(sig): except OSError: pass + # try SIGHUP logger.info( "Sending %s to group %s. PIDs of all processes in the group: %s", - sig, + signal.SIGHUP, process_group_id, [p.pid for p in all_processes_in_the_group], ) try: - signal_procs(sig) + signal_procs(signal.SIGHUP) except OSError as err: # No such process, which means there is no such process group - our job # is done if err.errno == errno.ESRCH: return returncodes - _, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate) + _, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_hangup) + + if alive: # try SIGTERM + logger.info( + "Sending %s to group %s. 
PIDs of all processes in the group: %s", + signal.SIGTERM, + process_group_id, + [p.pid for p in all_processes_in_the_group], + ) + try: + signal_procs(signal.SIGTERM) + except OSError as err: + # No such process, which means there is no such process group - our job + # is done + if err.errno == errno.ESRCH: + return returncodes + + _, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate) if alive: for proc in alive: diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index aefb77997e517..17101fd623e9f 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -698,7 +698,7 @@ def test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, capl ) lm.clear() - @pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL]) + @pytest.mark.parametrize("signal_type", [signal.SIGHUP, signal.SIGTERM, signal.SIGKILL]) def test_process_os_signal_calls_on_failure_callback( self, monkeypatch, tmp_path, get_test_dag, signal_type ): @@ -1042,6 +1042,8 @@ def test_process_sigterm_works_with_retries(self, mp_method, wait_timeout, daemo with timeout(wait_timeout, "Timeout during waiting start LocalTaskJob"): while task_started.value == 0: time.sleep(0.2) + + os.kill(proc.pid, signal.SIGHUP) os.kill(proc.pid, signal.SIGTERM) with timeout(wait_timeout, "Timeout during waiting callback"): diff --git a/tests/utils/test_process_utils.py b/tests/utils/test_process_utils.py index ac591248ae49b..eb9cc245e05cb 100644 --- a/tests/utils/test_process_utils.py +++ b/tests/utils/test_process_utils.py @@ -43,10 +43,11 @@ class TestReapProcessGroup: @staticmethod - def _ignores_sigterm(child_pid, child_setup_done): + def _ignores_sighup_and_sigterm(child_pid, child_setup_done): def signal_handler(unused_signum, unused_frame): pass + signal.signal(signal.SIGHUP, signal_handler) signal.signal(signal.SIGTERM, signal_handler) child_pid.value = os.getpid() 
child_setup_done.release() @@ -54,15 +55,16 @@ def signal_handler(unused_signum, unused_frame): time.sleep(1) @staticmethod - def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done): + def _parent_of_ignores_sighup_and_sigterm(parent_pid, child_pid, setup_done): def signal_handler(unused_signum, unused_frame): pass os.setsid() + signal.signal(signal.SIGHUP, signal_handler) signal.signal(signal.SIGTERM, signal_handler) child_setup_done = multiprocessing.Semaphore(0) child = multiprocessing.Process( - target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done] + target=TestReapProcessGroup._ignores_sighup_and_sigterm, args=[child_pid, child_setup_done] ) child.start() child_setup_done.acquire(timeout=5.0) @@ -80,7 +82,10 @@ def test_reap_process_group(self): parent_pid = multiprocessing.Value("i", 0) child_pid = multiprocessing.Value("i", 0) args = [parent_pid, child_pid, parent_setup_done] - parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args) + parent = multiprocessing.Process( + target=TestReapProcessGroup._parent_of_ignores_sighup_and_sigterm, + args=args, + ) try: parent.start() assert parent_setup_done.acquire(timeout=5.0)