Skip to content

Commit

Permalink
Implement MockDriver
Browse files Browse the repository at this point in the history
LocalDriver differs from the HPC drivers in that it is made of three
parts that are run in sequence. `init` the subprocess, `wait` for the
process to complete and `kill` when the user wants to cancel. Between
the three parts the driver needs to send `JobEvent`s to the `Scheduler`.

This commit implements a `MockDriver`, where the user can optionally
specify a simplified version of each of `init`, `wait` or `kill`,
depending on what they wish to do.
  • Loading branch information
pinkwah committed Dec 18, 2023
1 parent e867409 commit af717e2
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 84 deletions.
5 changes: 1 addition & 4 deletions src/ert/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import asyncio
import logging
import sys
from contextlib import asynccontextmanager
from traceback import print_exception
from typing import (
Any,
AsyncGenerator,
Expand Down Expand Up @@ -74,7 +72,6 @@ def _done_callback(task: asyncio.Task[_T_co]) -> None:
if (exc := task.exception()) is None:
return

print(f"Exception during {task.get_name()}", file=sys.stderr)
print_exception(exc, file=sys.stderr)
logger.error(f"Exception occurred during {task.get_name()}", exc_info=exc)
except asyncio.CancelledError:
pass
10 changes: 6 additions & 4 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ class Driver(ABC):
"""Adapter for the HPC cluster."""

def __init__(self) -> None:
self.event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None
self._event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None

async def ainit(self) -> None:
if self.event_queue is None:
self.event_queue = asyncio.Queue()
@property
def event_queue(self) -> asyncio.Queue[Tuple[int, JobEvent]]:
if self._event_queue is None:
self._event_queue = asyncio.Queue()
return self._event_queue

@abstractmethod
async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
Expand Down
58 changes: 43 additions & 15 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import asyncio
import os
from typing import MutableMapping
from asyncio.subprocess import Process
from typing import (
MutableMapping,
)

from ert.scheduler.driver import Driver, JobEvent

_TERMINATE_TIMEOUT = 10.0


class LocalDriver(Driver):
def __init__(self) -> None:
Expand All @@ -15,7 +20,7 @@ def __init__(self) -> None:
async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
await self.kill(iens)
self._tasks[iens] = asyncio.create_task(
self._wait_until_finish(iens, executable, *args, cwd=cwd)
self._run(iens, executable, *args, cwd=cwd)
)

async def kill(self, iens: int) -> None:
Expand All @@ -29,29 +34,52 @@ async def kill(self, iens: int) -> None:
async def finish(self) -> None:
await asyncio.gather(*self._tasks.values())

async def _wait_until_finish(
async def _run(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
try:
proc = await self._init(
iens,
executable,
*args,
cwd=cwd,
)
except Exception as exc:
print(f"{exc=}")
await self.event_queue.put((iens, JobEvent.FAILED))
return

await self.event_queue.put((iens, JobEvent.STARTED))
try:
if await self._wait(proc):
await self.event_queue.put((iens, JobEvent.COMPLETED))
else:
await self.event_queue.put((iens, JobEvent.FAILED))
except asyncio.CancelledError:
await self._kill(proc)
await self.event_queue.put((iens, JobEvent.ABORTED))

async def _init(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
proc = await asyncio.create_subprocess_exec(
) -> Process:
"""This method exists to allow for mocking it in tests"""
return await asyncio.create_subprocess_exec(
executable,
*args,
cwd=cwd,
preexec_fn=os.setpgrp,
)

if self.event_queue is None:
await self.ainit()
assert self.event_queue is not None
async def _wait(self, proc: Process) -> bool:
"""This method exists to allow for mocking it in tests"""
return await proc.wait() == 0

await self.event_queue.put((iens, JobEvent.STARTED))
async def _kill(self, proc: Process) -> None:
"""This method exists to allow for mocking it in tests"""
try:
if await proc.wait() == 0:
await self.event_queue.put((iens, JobEvent.COMPLETED))
else:
await self.event_queue.put((iens, JobEvent.FAILED))
except asyncio.CancelledError:
proc.terminate()
await self.event_queue.put((iens, JobEvent.ABORTED))
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)
except asyncio.TimeoutError:
proc.kill()
await asyncio.wait_for(proc.wait(), _TERMINATE_TIMEOUT)

async def poll(self) -> None:
"""LocalDriver does not poll"""
4 changes: 1 addition & 3 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,7 @@ async def ainit(self) -> None:
if self._events is None:
self._events = asyncio.Queue()

def add_realization(
self, real: Realization, callback_timeout: Callable[[int], None]
) -> None:
def add_realization(self, real: Realization, callback_timeout: Any = None) -> None:
self._jobs[real.iens] = Job(self, real)

def kill_all_jobs(self) -> None:
Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/scheduler/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from ert.scheduler.local_driver import LocalDriver


class MockDriver(LocalDriver):
def __init__(self, init=None, wait=None, kill=None):
super().__init__()
self._mock_init = init
self._mock_wait = wait
self._mock_kill = kill

async def _init(self, *args, **kwargs):
if self._mock_init is not None:
await self._mock_init(*args, **kwargs)

async def _wait(self, *args):
if self._mock_wait is not None:
result = await self._mock_wait()
return True if result is None else bool(result)
return True

async def _kill(self, *args):
if self._mock_kill is not None:
await self._mock_kill()


@pytest.fixture
def mock_driver():
return MockDriver
74 changes: 74 additions & 0 deletions tests/unit_tests/scheduler/test_local_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import asyncio

import pytest

from ert.scheduler import local_driver
from ert.scheduler.driver import JobEvent
from ert.scheduler.local_driver import LocalDriver


async def test_success(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", "touch", "testfile", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
assert await driver.event_queue.get() == (42, JobEvent.COMPLETED)

assert (tmp_path / "testfile").exists()


async def test_failure(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", "false", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
assert await driver.event_queue.get() == (42, JobEvent.FAILED)


async def test_file_not_found(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/file/not/found", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.FAILED)


async def test_kill(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", "sleep", "10", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await driver.kill(42)
assert await driver.event_queue.get() == (42, JobEvent.ABORTED)


@pytest.mark.timeout(5)
async def test_kill_unresponsive_process(monkeypatch, tmp_path):
# Reduce timeout to something more appropriate for a test
monkeypatch.setattr(local_driver, "_TERMINATE_TIMEOUT", 0.1)

(tmp_path / "script").write_text(
"""\
trap "" 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
sleep 60
"""
)

driver = LocalDriver()

await driver.submit(42, "/bin/sh", tmp_path / "script", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await driver.kill(42)
assert await driver.event_queue.get() == (42, JobEvent.ABORTED)


@pytest.mark.parametrize(
"cmd,event", [("true", JobEvent.COMPLETED), ("false", JobEvent.FAILED)]
)
async def test_kill_when_job_completed(tmp_path, cmd, event):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", cmd, cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await asyncio.sleep(0.5)
await driver.kill(42)
assert await driver.event_queue.get() == (42, event)
102 changes: 44 additions & 58 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,12 @@

import pytest

from ert.config.forward_model import ForwardModel
from ert.ensemble_evaluator._builder._realization import Realization
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.run_arg import RunArg
from ert.scheduler import scheduler


def create_bash_step(script: str) -> ForwardModel:
return ForwardModel(
name="bash_step",
executable="/usr/bin/env",
arglist=["bash", "-c", script],
)


def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None:
jobs = {
"global_environment": {},
"config_path": "/dev/null",
"config_file": "/dev/null",
"jobList": [
{
"name": step.name,
"executable": step.executable,
"argList": step.arglist,
}
for step in steps
],
"run_id": "0",
"ert_pid": "0",
"real_id": "0",
}

with open(path / "jobs.json", "w") as f:
json.dump(jobs, f)


@pytest.fixture
def realization(storage, tmp_path):
ensemble = storage.create_experiment().create_ensemble(name="foo", ensemble_size=1)
Expand Down Expand Up @@ -72,41 +41,51 @@ async def test_empty():
assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED


async def test_single_job(tmp_path: Path, realization):
step = create_bash_step("echo 'Hello, world!' > testfile")
realization.forward_models = [step]
async def test_single_job(realization, mock_driver):
future = asyncio.Future()

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=lambda _: None)
async def init(iens, *args, **kwargs):
future.set_result(iens)

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
driver = mock_driver(init=init)

sch = scheduler.Scheduler(driver)
sch.add_realization(realization)

assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert (tmp_path / "testfile").read_text() == "Hello, world!\n"
assert await future == realization.iens


async def test_cancel(tmp_path: Path, realization):
step = create_bash_step("touch a; sleep 10; touch b")
realization.forward_models = [step]
async def test_cancel(realization, mock_driver):
pre = asyncio.Event()
post = asyncio.Event()
killed = False

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=lambda _: None)
async def wait():
pre.set()
await asyncio.sleep(10)
post.set()

async def kill():
nonlocal killed
killed = True

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
driver = mock_driver(wait=wait, kill=kill)
sch = scheduler.Scheduler(driver)
sch.add_realization(realization)

scheduler_task = asyncio.create_task(sch.execute())

# Wait for the job to start (i.e. let the file "a" be touched)
await asyncio.sleep(1)
# Wait for the job to start
await asyncio.wait_for(pre.wait(), timeout=1)

# Kill all jobs and wait for the scheduler to complete
sch.kill_all_jobs()
await scheduler_task

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()
assert pre.is_set()
assert not post.is_set()
assert killed


@pytest.mark.parametrize(
Expand All @@ -117,14 +96,21 @@ async def test_cancel(tmp_path: Path, realization):
(3),
],
)
async def test_that_max_submit_was_reached(tmp_path: Path, realization, max_submit):
script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1"
step = create_bash_step(script)
realization.forward_models = [step]
sch = scheduler.Scheduler()
async def test_that_max_submit_was_reached(realization, max_submit, mock_driver):
retries = 0

async def init(*args, **kwargs):
nonlocal retries
retries += 1

async def wait():
return False

driver = mock_driver(init=init, wait=wait)
sch = scheduler.Scheduler(driver)

sch._max_submit = max_submit
sch.add_realization(realization, callback_timeout=lambda _: None)
create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()

assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert (tmp_path / "cnt").read_text() == f"{max_submit}\n"
assert retries == max_submit

0 comments on commit af717e2

Please sign in to comment.