From af717e21b46234306117e9e8693e81b3a84cf60c Mon Sep 17 00:00:00 2001 From: Zohar Malamant Date: Mon, 11 Dec 2023 16:15:54 +0100 Subject: [PATCH] Implement MockDriver LocalDriver differs from the HPC drivers in that it is made of three parts that are run in sequence. `init` the subprocess, `wait` for the process to complete and `kill` when the user wants to cancel. Between the three parts the driver needs to send `JobEvent`s to the `Scheduler`. This commit implements a `MockDriver`, where the user can optionally specify a simplified version of each of `init`, `wait` or `kill`, depending on what they wish to do. --- src/ert/async_utils.py | 5 +- src/ert/scheduler/driver.py | 10 +- src/ert/scheduler/local_driver.py | 58 +++++++--- src/ert/scheduler/scheduler.py | 4 +- tests/unit_tests/scheduler/conftest.py | 30 ++++++ .../unit_tests/scheduler/test_local_driver.py | 74 +++++++++++++ tests/unit_tests/scheduler/test_scheduler.py | 102 ++++++++---------- 7 files changed, 199 insertions(+), 84 deletions(-) create mode 100644 tests/unit_tests/scheduler/conftest.py create mode 100644 tests/unit_tests/scheduler/test_local_driver.py diff --git a/src/ert/async_utils.py b/src/ert/async_utils.py index cbf319e5ab7..432213b227f 100644 --- a/src/ert/async_utils.py +++ b/src/ert/async_utils.py @@ -2,9 +2,7 @@ import asyncio import logging -import sys from contextlib import asynccontextmanager -from traceback import print_exception from typing import ( Any, AsyncGenerator, @@ -74,7 +72,6 @@ def _done_callback(task: asyncio.Task[_T_co]) -> None: if (exc := task.exception()) is None: return - print(f"Exception during {task.get_name()}", file=sys.stderr) - print_exception(exc, file=sys.stderr) + logger.error(f"Exception occurred during {task.get_name()}", exc_info=exc) except asyncio.CancelledError: pass diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 7c15c51fe9b..81dacd716cc 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -17,11 +17,13 @@ class Driver(ABC): """Adapter for the HPC cluster.""" def __init__(self) -> None: - self.event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None + self._event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None - async def ainit(self) -> None: - if self.event_queue is None: - self.event_queue = asyncio.Queue() + @property + def event_queue(self) -> asyncio.Queue[Tuple[int, JobEvent]]: + if self._event_queue is None: + self._event_queue = asyncio.Queue() + return self._event_queue @abstractmethod async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None: diff --git a/src/ert/scheduler/local_driver.py b/src/ert/scheduler/local_driver.py index e88ae5453be..c4b9ae3bc40 100644 --- a/src/ert/scheduler/local_driver.py +++ b/src/ert/scheduler/local_driver.py @@ -2,10 +2,15 @@ import asyncio import os -from typing import MutableMapping +from asyncio.subprocess import Process +from typing import ( + MutableMapping, +) from ert.scheduler.driver import Driver, JobEvent +_TERMINATE_TIMEOUT = 10.0 + class LocalDriver(Driver): def __init__(self) -> None: @@ -15,7 +20,7 @@ def __init__(self) -> None: async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None: await self.kill(iens) self._tasks[iens] = asyncio.create_task( - self._wait_until_finish(iens, executable, *args, cwd=cwd) + self._run(iens, executable, *args, cwd=cwd) ) async def kill(self, iens: int) -> None: @@ -29,29 +34,52 @@ async def kill(self, iens: int) -> None: async def finish(self) -> None: await asyncio.gather(*self._tasks.values()) - async def _wait_until_finish( + async def _run(self, iens: int, executable: str, /, *args: str, cwd: str) -> None: + try: + proc = await self._init( + iens, + executable, + *args, + cwd=cwd, + ) + except Exception as exc: + print(f"{exc=}") + await self.event_queue.put((iens, JobEvent.FAILED)) + return + + await self.event_queue.put((iens, JobEvent.STARTED)) + try: + if await self._wait(proc): + await self.event_queue.put((iens, JobEvent.COMPLETED)) + else: + await self.event_queue.put((iens, JobEvent.FAILED)) + except asyncio.CancelledError: + await self._kill(proc) + await self.event_queue.put((iens, JobEvent.ABORTED)) + + async def _init( self, iens: int, executable: str, /, *args: str, cwd: str - ) -> None: - proc = await asyncio.create_subprocess_exec( + ) -> Process: + """This method exists to allow for mocking it in tests""" + return await asyncio.create_subprocess_exec( executable, *args, cwd=cwd, preexec_fn=os.setpgrp, ) - if self.event_queue is None: - await self.ainit() - assert self.event_queue is not None + async def _wait(self, proc: Process) -> bool: + """This method exists to allow for mocking it in tests""" + return await proc.wait() == 0 - await self.event_queue.put((iens, JobEvent.STARTED)) + async def _kill(self, proc: Process) -> None: + """This method exists to allow for mocking it in tests""" try: - if await proc.wait() == 0: - await self.event_queue.put((iens, JobEvent.COMPLETED)) - else: - await self.event_queue.put((iens, JobEvent.FAILED)) - except asyncio.CancelledError: proc.terminate() - await self.event_queue.put((iens, JobEvent.ABORTED)) + await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT) + except asyncio.TimeoutError: + proc.kill() + await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT) async def poll(self) -> None: """LocalDriver does not poll""" diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index d0a2bbc5388..a49f94d938c 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -59,9 +59,7 @@ async def ainit(self) -> None: if self._events is None: self._events = asyncio.Queue() - def add_realization( - self, real: Realization, callback_timeout: Callable[[int], None] - ) -> None: + def add_realization(self, real: Realization, callback_timeout: Any = None) -> None: self._jobs[real.iens] = Job(self, real) def kill_all_jobs(self) -> None: diff --git a/tests/unit_tests/scheduler/conftest.py b/tests/unit_tests/scheduler/conftest.py new file mode 100644 index 00000000000..370524af461 --- /dev/null +++ b/tests/unit_tests/scheduler/conftest.py @@ -0,0 +1,30 @@ +import pytest + +from ert.scheduler.local_driver import LocalDriver + + +class MockDriver(LocalDriver): + def __init__(self, init=None, wait=None, kill=None): + super().__init__() + self._mock_init = init + self._mock_wait = wait + self._mock_kill = kill + + async def _init(self, *args, **kwargs): + if self._mock_init is not None: + await self._mock_init(*args, **kwargs) + + async def _wait(self, *args): + if self._mock_wait is not None: + result = await self._mock_wait() + return True if result is None else bool(result) + return True + + async def _kill(self, *args): + if self._mock_kill is not None: + await self._mock_kill() + + +@pytest.fixture +def mock_driver(): + return MockDriver diff --git a/tests/unit_tests/scheduler/test_local_driver.py b/tests/unit_tests/scheduler/test_local_driver.py new file mode 100644 index 00000000000..5758bf6439d --- /dev/null +++ b/tests/unit_tests/scheduler/test_local_driver.py @@ -0,0 +1,74 @@ +import asyncio + +import pytest + +from ert.scheduler import local_driver +from ert.scheduler.driver import JobEvent +from ert.scheduler.local_driver import LocalDriver + + +async def test_success(tmp_path): + driver = LocalDriver() + + await driver.submit(42, "/usr/bin/env", "touch", "testfile", cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.STARTED) + assert await driver.event_queue.get() == (42, JobEvent.COMPLETED) + + assert (tmp_path / "testfile").exists() + + +async def test_failure(tmp_path): + driver = LocalDriver() + + await driver.submit(42, "/usr/bin/env", "false", cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.STARTED) + assert await driver.event_queue.get() == (42, JobEvent.FAILED) + + +async def test_file_not_found(tmp_path): + driver = LocalDriver() + + await driver.submit(42, "/file/not/found", cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.FAILED) + + +async def test_kill(tmp_path): + driver = LocalDriver() + + await driver.submit(42, "/usr/bin/env", "sleep", "10", cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.STARTED) + await driver.kill(42) + assert await driver.event_queue.get() == (42, JobEvent.ABORTED) + + +@pytest.mark.timeout(5) +async def test_kill_unresponsive_process(monkeypatch, tmp_path): + # Reduce timeout to something more appropriate for a test + monkeypatch.setattr(local_driver, "_TERMINATE_TIMEOUT", 0.1) + + (tmp_path / "script").write_text( + """\ + trap "" 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + sleep 60 + """ + ) + + driver = LocalDriver() + + await driver.submit(42, "/bin/sh", tmp_path / "script", cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.STARTED) + await driver.kill(42) + assert await driver.event_queue.get() == (42, JobEvent.ABORTED) + + +@pytest.mark.parametrize( + "cmd,event", [("true", JobEvent.COMPLETED), ("false", JobEvent.FAILED)] +) +async def test_kill_when_job_completed(tmp_path, cmd, event): + driver = LocalDriver() + + await driver.submit(42, "/usr/bin/env", cmd, cwd=tmp_path) + assert await driver.event_queue.get() == (42, JobEvent.STARTED) + await asyncio.sleep(0.5) + await driver.kill(42) + assert await driver.event_queue.get() == (42, event) diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 45ec2cc2aaf..ea9da7b3b7f 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -6,43 +6,12 @@ import pytest -from ert.config.forward_model import ForwardModel from ert.ensemble_evaluator._builder._realization import Realization from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED from ert.run_arg import RunArg from ert.scheduler import scheduler -def create_bash_step(script: str) -> ForwardModel: - return ForwardModel( - name="bash_step", - executable="/usr/bin/env", - arglist=["bash", "-c", script], - ) - - -def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None: - jobs = { - "global_environment": {}, - "config_path": "/dev/null", - "config_file": "/dev/null", - "jobList": [ - { - "name": step.name, - "executable": step.executable, - "argList": step.arglist, - } - for step in steps - ], - "run_id": "0", - "ert_pid": "0", - "real_id": "0", - } - - with open(path / "jobs.json", "w") as f: - json.dump(jobs, f) - - @pytest.fixture def realization(storage, tmp_path): ensemble = storage.create_experiment().create_ensemble(name="foo", ensemble_size=1) @@ -72,41 +41,51 @@ async def test_empty(): assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED -async def test_single_job(tmp_path: Path, realization): - step = create_bash_step("echo 'Hello, world!' > testfile") - realization.forward_models = [step] +async def test_single_job(realization, mock_driver): + future = asyncio.Future() - sch = scheduler.Scheduler() - sch.add_realization(realization, callback_timeout=lambda _: None) + async def init(iens, *args, **kwargs): + future.set_result(iens) - create_jobs_json(tmp_path, [step]) - sch.add_dispatch_information_to_jobs_file() + driver = mock_driver(init=init) + + sch = scheduler.Scheduler(driver) + sch.add_realization(realization) assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED - assert (tmp_path / "testfile").read_text() == "Hello, world!\n" + assert await future == realization.iens -async def test_cancel(tmp_path: Path, realization): - step = create_bash_step("touch a; sleep 10; touch b") - realization.forward_models = [step] +async def test_cancel(realization, mock_driver): + pre = asyncio.Event() + post = asyncio.Event() + killed = False - sch = scheduler.Scheduler() - sch.add_realization(realization, callback_timeout=lambda _: None) + async def wait(): + pre.set() + await asyncio.sleep(10) + post.set() + + async def kill(): + nonlocal killed + killed = True - create_jobs_json(tmp_path, [step]) - sch.add_dispatch_information_to_jobs_file() + driver = mock_driver(wait=wait, kill=kill) + sch = scheduler.Scheduler(driver) + sch.add_realization(realization) scheduler_task = asyncio.create_task(sch.execute()) - # Wait for the job to start (i.e. let the file "a" be touched) - await asyncio.sleep(1) + # Wait for the job to start + await asyncio.wait_for(pre.wait(), timeout=1) # Kill all jobs and wait for the scheduler to complete sch.kill_all_jobs() await scheduler_task - assert (tmp_path / "a").exists() - assert not (tmp_path / "b").exists() + assert pre.is_set() + assert not post.is_set() + assert killed @pytest.mark.parametrize( @@ -117,14 +96,21 @@ async def test_cancel(tmp_path: Path, realization): (3), ], ) -async def test_that_max_submit_was_reached(tmp_path: Path, realization, max_submit): - script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1" - step = create_bash_step(script) - realization.forward_models = [step] - sch = scheduler.Scheduler() +async def test_that_max_submit_was_reached(realization, max_submit, mock_driver): + retries = 0 + + async def init(*args, **kwargs): + nonlocal retries + retries += 1 + + async def wait(): + return False + + driver = mock_driver(init=init, wait=wait) + sch = scheduler.Scheduler(driver) + sch._max_submit = max_submit sch.add_realization(realization, callback_timeout=lambda _: None) - create_jobs_json(tmp_path, [step]) - sch.add_dispatch_information_to_jobs_file() + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED - assert (tmp_path / "cnt").read_text() == f"{max_submit}\n" + assert retries == max_submit