Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SIGHUP signal before SIGTERM for task runner termination #88

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions airflow/jobs/local_task_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def sigusr2_debug_handler(signum, frame):
return_code = None
try:
self.task_runner.start()
self.task_started_at = timezone.utcnow()
self.task_execution_timeout = (
self.task_instance.task.execution_timeout if self.task_instance.task is not None else None
)
local_task_job_heartbeat_sec = conf.getint("scheduler", "local_task_job_heartbeat_sec")
if local_task_job_heartbeat_sec < 1:
heartbeat_time_limit = conf.getint("scheduler", "scheduler_zombie_task_threshold")
Expand Down Expand Up @@ -234,6 +238,31 @@ def sigusr2_debug_handler(signum, frame):
f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit "
f"({heartbeat_time_limit}s)."
)

if self.task_execution_timeout is not None:
# If the time elapsed is longer than the execution timeout,
# then we need to terminate the process
elapsed_time = timezone.utcnow() - self.task_started_at
timed_out = elapsed_time > self.task_execution_timeout
if timed_out:
Stats.incr("local_task_job_execution_timeout", 1, 1)
self.log.error(
"The task (#%s) timed out after %s seconds! Terminating...",
self.task_instance.task_id,
self.task_execution_timeout.total_seconds(),
)
if span.is_recording():
span.add_event(
name="error",
attributes={
"message": "Task timeout",
"task_execution_timeout(s)": self.task_execution_timeout.total_seconds(),
"elapsed_time(s)": elapsed_time.total_seconds(),
},
)
self.task_runner.terminate()
break

return return_code
finally:
# Print a marker for log grouping of details before task execution
Expand Down
45 changes: 33 additions & 12 deletions airflow/utils/process_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,23 +46,23 @@

log = logging.getLogger(__name__)

# When killing processes, time to wait after issuing a SIGTERM before issuing a
# SIGKILL.
DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM = conf.getint("core", "KILLED_TASK_CLEANUP_TIME")
# When killing processes, time to wait after issuing a signal (e.g. SIGHUP, SIGTERM) before issuing
# another signal (e.g. SIGTERM, SIGKILL).
# Possible WAIT scenarios between signals: SIGHUP -> WAIT -> SIGTERM -> WAIT -> SIGKILL
DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS = conf.getint("core", "KILLED_TASK_CLEANUP_TIME")


def reap_process_group(
process_group_id: int,
logger,
sig: signal.Signals = signal.SIGTERM,
timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGTERM,
timeout: int = DEFAULT_TIME_TO_WAIT_AFTER_SIGNALS,
) -> dict[int, int]:
"""
Send sig (SIGTERM) to the process group of pid.
Send sig (SIGHUP) to the process group of pid.

Tries really hard to terminate all processes in the group (including grandchildren). Will send
sig (SIGTERM) to the process group of pid. If any process is alive after timeout
a SIGKILL will be send.
sig (SIGHUP) to the process group of pid. If any process is alive after timeout
a SIGTERM will be sent. If any process is still alive then a SIGKILL will be sent.

:param process_group_id: process group id to kill.
The process that wants to create the group should run
Expand All @@ -71,11 +71,14 @@ def reap_process_group(
"root" of the group has pid = gid and all other processes in the group have different
pids but the same gid (equal the pid of the root process)
:param logger: log handler
:param sig: signal type
:param timeout: how much time a process has to terminate
"""
returncodes = {}

def on_hangup(p):
logger.info("Process %s (%s) hung up with exit code %s", p, p.pid, p.returncode)
returncodes[p.pid] = p.returncode

def on_terminate(p):
logger.info("Process %s (%s) terminated with exit code %s", p, p.pid, p.returncode)
returncodes[p.pid] = p.returncode
Expand Down Expand Up @@ -129,21 +132,39 @@ def signal_procs(sig):
except OSError:
pass

# try SIGHUP
logger.info(
"Sending %s to group %s. PIDs of all processes in the group: %s",
sig,
signal.SIGHUP,
process_group_id,
[p.pid for p in all_processes_in_the_group],
)
try:
signal_procs(sig)
signal_procs(signal.SIGHUP)
except OSError as err:
# No such process, which means there is no such process group - our job
# is done
if err.errno == errno.ESRCH:
return returncodes

_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate)
_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_hangup)

if alive: # try SIGTERM
logger.info(
"Sending %s to group %s. PIDs of all processes in the group: %s",
signal.SIGTERM,
process_group_id,
[p.pid for p in all_processes_in_the_group],
)
try:
signal_procs(signal.SIGTERM)
except OSError as err:
# No such process, which means there is no such process group - our job
# is done
if err.errno == errno.ESRCH:
return returncodes

_, alive = psutil.wait_procs(all_processes_in_the_group, timeout=timeout, callback=on_terminate)

if alive:
for proc in alive:
Expand Down
4 changes: 3 additions & 1 deletion tests/jobs/test_local_task_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def test_success_slow_task_not_killed_by_overtime_but_regular_timeout(self, capl
)
lm.clear()

@pytest.mark.parametrize("signal_type", [signal.SIGTERM, signal.SIGKILL])
@pytest.mark.parametrize("signal_type", [signal.SIGHUP, signal.SIGTERM, signal.SIGKILL])
def test_process_os_signal_calls_on_failure_callback(
self, monkeypatch, tmp_path, get_test_dag, signal_type
):
Expand Down Expand Up @@ -1042,6 +1042,8 @@ def test_process_sigterm_works_with_retries(self, mp_method, wait_timeout, daemo
with timeout(wait_timeout, "Timeout during waiting start LocalTaskJob"):
while task_started.value == 0:
time.sleep(0.2)

os.kill(proc.pid, signal.SIGHUP)
os.kill(proc.pid, signal.SIGTERM)

with timeout(wait_timeout, "Timeout during waiting callback"):
Expand Down
13 changes: 9 additions & 4 deletions tests/utils/test_process_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,26 +43,28 @@

class TestReapProcessGroup:
@staticmethod
def _ignores_sigterm(child_pid, child_setup_done):
def _ignores_sighup_and_sigterm(child_pid, child_setup_done):
def signal_handler(unused_signum, unused_frame):
pass

signal.signal(signal.SIGHUP, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
child_pid.value = os.getpid()
child_setup_done.release()
while True:
time.sleep(1)

@staticmethod
def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done):
def _parent_of_ignores_sighup_and_sigterm(parent_pid, child_pid, setup_done):
def signal_handler(unused_signum, unused_frame):
pass

os.setsid()
signal.signal(signal.SIGHUP, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
child_setup_done = multiprocessing.Semaphore(0)
child = multiprocessing.Process(
target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done]
target=TestReapProcessGroup._ignores_sighup_and_sigterm, args=[child_pid, child_setup_done]
)
child.start()
child_setup_done.acquire(timeout=5.0)
Expand All @@ -80,7 +82,10 @@ def test_reap_process_group(self):
parent_pid = multiprocessing.Value("i", 0)
child_pid = multiprocessing.Value("i", 0)
args = [parent_pid, child_pid, parent_setup_done]
parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args)
parent = multiprocessing.Process(
target=TestReapProcessGroup._parent_of_ignores_sighup_and_sigterm,
args=args,
)
try:
parent.start()
assert parent_setup_done.acquire(timeout=5.0)
Expand Down
Loading