diff --git a/src/ert/services/_storage_main.py b/src/ert/services/_storage_main.py index 487e6870e12..cfa83049be5 100644 --- a/src/ert/services/_storage_main.py +++ b/src/ert/services/_storage_main.py @@ -7,7 +7,8 @@ import signal import socket import string -import sys +import threading +import time import warnings from typing import Any, Dict, List, Optional, Union @@ -123,30 +124,23 @@ def run_server(args: Optional[argparse.Namespace] = None, debug: bool = False) - server.run(sockets=[sock]) -def terminate_on_parent_death() -> None: - """Quit the server when the parent does a SIGABRT or is otherwise destroyed. - This functionality has existed on Linux for a good while, but it isn't - exposed in the Python standard library. Use ctypes to hook into the - functionality. +def terminate_on_parent_death( + stopped: threading.Event, poll_interval: float = 1.0 +) -> None: + """ + Quit the server when the parent process is no longer running. """ - if sys.platform != "linux" or "ERT_COMM_FD" not in os.environ: - return - - from ctypes import CDLL, c_int, c_ulong # noqa: PLC0415 - - lib = CDLL(None) - # from - # int prctl(int option, ...) - prctl = lib.prctl - prctl.restype = c_int - prctl.argtypes = (c_int, c_ulong) + def check_parent_alive() -> bool: + return os.getppid() != 1 - # from - PR_SET_PDEATHSIG = 1 + while check_parent_alive(): + if stopped.is_set(): + return + time.sleep(poll_interval) - # connect parent death signal to our SIGTERM - prctl(PR_SET_PDEATHSIG, signal.SIGTERM) + # Parent is no longer alive, terminate this process. + os.kill(os.getpid(), signal.SIGTERM) if __name__ == "__main__": @@ -156,6 +150,14 @@ def terminate_on_parent_death() -> None: warnings.filterwarnings("ignore", category=DeprecationWarning) uvicorn.config.LOGGING_CONFIG.clear() uvicorn.config.LOGGING_CONFIG.update(logging_conf) - terminate_on_parent_death() + _stopped = threading.Event() + terminate_on_parent_death_thread = threading.Thread( + target=terminate_on_parent_death, args=[_stopped, 1.0] + ) with ErtPluginContext(logger=logging.getLogger()) as context: - run_server(debug=False) + terminate_on_parent_death_thread.start() + try: + run_server(debug=False) + finally: + _stopped.set() + terminate_on_parent_death_thread.join() diff --git a/tests/ert/ui_tests/cli/test_cli.py b/tests/ert/ui_tests/cli/test_cli.py index 487e2fcade4..fb775e6a886 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -1,3 +1,4 @@ +import asyncio import fileinput import json import logging @@ -13,6 +14,7 @@ import pytest import websockets.exceptions import xtgeo +from psutil import NoSuchProcess, Popen, Process, ZombieProcess from resdata.summary import Summary import _ert.threading @@ -972,3 +974,43 @@ def raise_connection_error(*args, **kwargs): ENSEMBLE_EXPERIMENT_MODE, "poly.ert", ) + + +@pytest.mark.usefixtures("copy_poly_case") +async def test_that_killed_ert_does_not_leave_storage_server_process(): + ert_subprocess = Popen(["ert", "gui", "poly.ert"]) + assert ert_subprocess.is_running() + + async def _find_storage_process_pid() -> int: + while True: + for ert_child_process in ert_subprocess.children(): + try: + if "storage" in "".join(ert_child_process.cmdline()): + return ert_child_process.pid + except (ZombieProcess, NoSuchProcess): + pass + await asyncio.sleep(0.05) + + storage_process_pid = await asyncio.wait_for( + _find_storage_process_pid(), timeout=120 + ) + # wait for storage server to have connected to ert + await asyncio.sleep(5) + storage_process = Process(storage_process_pid) + + assert ert_subprocess.is_running() + assert storage_process.is_running() + kill_ert_subprocess = await asyncio.create_subprocess_exec( + "kill", "-9", f"{ert_subprocess.pid}" + ) + await kill_ert_subprocess.wait() + + async def _wait_for_storage_process_to_shut_down(): + storage_server_has_shutdown = asyncio.Event() + while not storage_server_has_shutdown.is_set(): + if not storage_process.is_running(): + storage_server_has_shutdown.set() + await asyncio.sleep(0.1) + + await asyncio.wait_for(_wait_for_storage_process_to_shut_down(), timeout=45) + assert not storage_process.is_running()