Add SIGHUP signal before SIGTERM for task runner termination
molcay committed Aug 7, 2024
1 parent b014077 commit 643d8bc
Showing 4 changed files with 74 additions and 17 deletions.
29 changes: 29 additions & 0 deletions airflow/jobs/local_task_job_runner.py
@@ -169,6 +169,10 @@ def sigusr2_debug_handler(signum, frame):
return_code = None
try:
self.task_runner.start()
self.task_started_at = timezone.utcnow()
self.task_execution_timeout = (
self.task_instance.task.execution_timeout if self.task_instance.task is not None else None
)
local_task_job_heartbeat_sec = conf.getint("scheduler", "local_task_job_heartbeat_sec")
if local_task_job_heartbeat_sec < 1:
heartbeat_time_limit = conf.getint("scheduler", "scheduler_zombie_task_threshold")
@@ -234,6 +238,31 @@ def sigusr2_debug_handler(signum, frame):
f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit "
f"({heartbeat_time_limit}s)."
)

if self.task_execution_timeout is not None:
# If the time elapsed is longer than the execution timeout,
# then we need to terminate the process
elapsed_time = timezone.utcnow() - self.task_started_at
timed_out = elapsed_time > self.task_execution_timeout
if timed_out:
Stats.incr("local_task_job_execution_timeout", 1, 1)
self.log.error(
"The task (#%s) timed out after %s seconds! Terminating...",
self.task_instance.task_id,
self.task_execution_timeout.total_seconds(),
)
if span.is_recording():
span.add_event(
name="error",
attributes={
"message": "Task timeout",
"task_execution_timeout(s)": self.task_execution_timeout.total_seconds(),
"elapsed_time(s)": elapsed_time.total_seconds(),
},
)
self.task_runner.terminate()
break

return return_code
finally:
# Print a marker for log grouping of details before task execution
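The new check compares wall-clock time since task_runner.start() against the task's execution_timeout inside the heartbeat loop and terminates the runner once it is exceeded. As a rough illustration of where that value comes from (a hedged sketch, not part of this commit; it assumes Airflow 2.x (2.4+) imports and an invented DAG id), execution_timeout is the ordinary operator argument:

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.bash import BashOperator

with DAG(
    dag_id="execution_timeout_demo",  # hypothetical DAG id for illustration
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # Runs for ~2 minutes but is only allowed 30 seconds, so the elapsed-time
    # check in LocalTaskJobRunner above should terminate the task runner and
    # increment the local_task_job_execution_timeout metric.
    BashOperator(
        task_id="slow_task",
        bash_command="sleep 120",
        execution_timeout=timedelta(seconds=30),
    )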
45 changes: 33 additions & 12 deletions airflow/utils/process_utils.py
@@ -46,23 +46,23 @@

log = logging.getLogger(__name__)

-# When killing processes, time to wait after issuing a SIGTERM before issuing a
-# SIGKILL.
-DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint("core", "KILLED_TASK_CLEANUP_TIME")
+# When killing processes, time to wait after issuing a signal (e.g. SIGHUP, SIGTERM) before issuing
+# another signal (e.g. SIGTERM, SIGKILL).
+# Possible WAIT scenarios between signals: SIGHUP -> WAIT -> SIGTERM -> WAIT -> SIGKILL
+DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS = conf.getint("core", "KILLED_TASK_CLEANUP_TIME")


def reap_process_group(
process_group_id: int,
logger,
sig: signal.Signals = signal.SIGTERM,
-timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM,
+timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS,
) -> dict[int, int]:
"""
-Send sig (SIGTERM) to the process group of pid.
+Send sig (SIGHUP) to the process group of pid.
Tries really hard to terminate all processes in the group (including grandchildren). Will send
-sig (SIGTERM) to the process group of pid. If any process is alive after timeout
-a SIGKILL will be send.
+sig (SIGHUP) to the process group of pid. If any process is alive after timeout
+a SIGTERM will be sent. If any process is still alive then a SIGKILL will be sent.
:param process_group_id: process group id to kill.
The process that wants to create the group should run
Expand All @@ -71,11 +71,14 @@ def reap_process_group(
"root" of the group has pid = gid and all other processes in the group have different
pids but the same gid (equal the pid of the root process)
:param logger: log handler
:param sig: signal type
:param timeout: how much time a process has to terminate
"""
returncodes = {}

+def on_hangup(p):
+logger.info("Process %s (%s) hung up with exit code %s", p, p.pid, p.returncode)
+returncodes[p.pid] = p.returncode

def on_terminate(p):
logger.info("Process %s (%s) terminated with exit code %s", p, p.pid, p.returncode)
returncodes[p.pid] = p.returncode
Expand Down Expand Up @@ -129,21 +132,39 @@ def signal_procs(sig):
except OSError:
pass

+# try SIGHUP
logger.info(
"Sending %s to group %s. PIDs of all processes in the group: %s",
-sig,
+signal.SIGHUP,
process_group_id,
[p.pid for p in all_processes_in_the_group],
)
try:
-signal_procs(sig)
+signal_procs(signal.SIGHUP)
except OSError as err:
# No such process, which means there is no such process group - our job
# is done
if err.errno == errno.ESRCH:
return returncodes

-_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate)
+_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_hangup)
+
+if alive:  # try SIGTERM
+logger.info(
+"Sending %s to group %s. PIDs of all processes in the group: %s",
+signal.SIGTERM,
+process_group_id,
+[p.pid for p in all_processes_in_the_group],
+)
+try:
+signal_procs(signal.SIGTERM)
+except OSError as err:
+# No such process, which means there is no such process group - our job
+# is done
+if err.errno == errno.ESRCH:
+return returncodes
+
+_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate)

if alive:
for proc in alive:
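Taken together, the updated reap_process_group escalates SIGHUP -> wait -> SIGTERM -> wait -> SIGKILL. A standalone sketch of that pattern is below, assuming psutil is installed, a POSIX platform, and a group whose leader pid equals the group id (as the docstring requires); the function name and default timeout are illustrative, not Airflow APIs.

import os
import signal

import psutil


def escalate_group(process_group_id: int, timeout: int = 5) -> None:
    # Collect the group leader plus all descendants so we can wait on them.
    leader = psutil.Process(process_group_id)
    procs = [leader, *leader.children(recursive=True)]
    for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGKILL):
        try:
            os.killpg(process_group_id, sig)  # signal the whole group at once
        except ProcessLookupError:
            return  # the group is already gone
        _, alive = psutil.wait_procs(procs, timeout=timeout)
        if not alive:
            return  # everything exited after this signal
        procs = alive  # escalate only for the survivors

The committed version additionally records exit codes through the on_hangup/on_terminate callbacks and takes the wait interval from the core KILLED_TASK_CLEANUP_TIME setting, but the escalation order is the same.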
4 changes: 3 additions & 1 deletion tests/jobs/test_local_task_job.py
@@ -698,7 +698,7 @@ def test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, capl
)
lm.clear()

@pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL])
@pytest.mark.parametrize("signal_type", [signal.SIGHUP, signal.SIGTERM, signal.SIGKILL])
def test_process_os_signal_calls_on_failure_callback(
self, monkeypatch, tmp_path, get_test_dag, signal_type
):
@@ -1042,6 +1042,8 @@ def test_process_sigterm_works_with_retries(self, mp_method, wait_timeout, daemo
with timeout(wait_timeout, "Timeout during waiting start LocalTaskJob"):
while task_started.value == 0:
time.sleep(0.2)

+os.kill(proc.pid, signal.SIGHUP)
os.kill(proc.pid, signal.SIGTERM)

with timeout(wait_timeout, "Timeout during waiting callback"):
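The parametrized test now covers SIGHUP as well, and test_process_sigterm_works_with_retries delivers SIGHUP just before SIGTERM. One plausible reading of sending SIGHUP first is that it gives the task process an earlier notification it can react to before the harder signals arrive; the sketch below illustrates that from the task side and is not part of the test suite (handler names and messages are invented).

import signal
import sys
import time


def on_sighup(signum, frame):
    # Hypothetical early-warning hook: flush buffers, close connections, etc.
    print("SIGHUP received: starting cleanup", flush=True)


def on_sigterm(signum, frame):
    print("SIGTERM received: exiting", flush=True)
    sys.exit(1)


signal.signal(signal.SIGHUP, on_sighup)
signal.signal(signal.SIGTERM, on_sigterm)

while True:  # stand-in for long-running task work
    time.sleep(1)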
13 changes: 9 additions & 4 deletions tests/utils/test_process_utils.py
@@ -43,26 +43,28 @@

class TestReapProcessGroup:
@staticmethod
-def _ignores_sigterm(child_pid, child_setup_done):
+def _ignores_sighup_and_sigterm(child_pid, child_setup_done):
def signal_handler(unused_signum, unused_frame):
pass

+signal.signal(signal.SIGHUP, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
child_pid.value = os.getpid()
child_setup_done.release()
while True:
time.sleep(1)

@staticmethod
-def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done):
+def _parent_of_ignores_sighup_and_sigterm(parent_pid, child_pid, setup_done):
def signal_handler(unused_signum, unused_frame):
pass

os.setsid()
+signal.signal(signal.SIGHUP, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
child_setup_done = multiprocessing.Semaphore(0)
child = multiprocessing.Process(
-target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done]
+target=TestReapProcessGroup._ignores_sighup_and_sigterm, args=[child_pid, child_setup_done]
)
child.start()
child_setup_done.acquire(timeout=5.0)
@@ -80,7 +82,10 @@ def test_reap_process_group(self):
parent_pid = multiprocessing.Value("i", 0)
child_pid = multiprocessing.Value("i", 0)
args = [parent_pid, child_pid, parent_setup_done]
-parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args)
+parent = multiprocessing.Process(
+target=TestReapProcessGroup._parent_of_ignores_sighup_and_sigterm,
+args=args,
+)
try:
parent.start()
assert parent_setup_done.acquire(timeout=5.0)
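The updated fixtures call os.setsid() so the parent of the group has pid == gid, which is the precondition reap_process_group documents. The following sketch shows the equivalent setup with subprocess instead of multiprocessing (POSIX only; illustrative, not Airflow or test-suite code):

import os
import signal
import subprocess

# start_new_session=True runs setsid() in the child, so the child becomes the
# leader of a fresh process group: its gid equals its pid.
proc = subprocess.Popen(["sleep", "60"], start_new_session=True)
assert os.getpgid(proc.pid) == proc.pid

# Signalling the group id reaches the leader and any children it spawns,
# which is what reap_process_group relies on.
os.killpg(proc.pid, signal.SIGHUP)  # plain `sleep` exits on SIGHUP by default
proc.wait(timeout=5)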
