Skip to content

Commit

Permalink
Implement MockDriver
Browse files Browse the repository at this point in the history
LocalDriver differs from the HPC drivers in that it is made of three
parts that are run in sequence. `init` the subprocess, `wait` for the
process to complete and `kill` when the user wants to cancel. Between
the three parts the driver needs to send `JobEvent`s to the `Scheduler`.

This commit implements a `MockDriver`, where the user can optionally
specify a simplified version of each of `init`, `wait` or `kill`,
depending on what they wish to do.

To achieve this, we decople the sending of `JobEvent`s from the
`LocalDriver` into a new driver called `TaskDriver`. `LocalDriver` then
becomes a very bare-bones implementation of just what is needed for
`asyncio.subprocess.Process`.

Thus, it is now possible to use `MockDriver` and specify the following
functions:
- `init` is called with the program arguments and awaited
- `JobEvent.STARTED` is sent
- `wait` is awaited
- Depending on the result of `wait` (`True` or `False`),
  `JobEvent.COMPLETED` or `JobEvent.FAILED` is sent and we are done
- If at any point the user wants to cancel, `kill` is awaited
- `JobEvent.ABORTED` is then sent
  • Loading branch information
pinkwah committed Dec 12, 2023
1 parent 050d02a commit 159b0c7
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 72 deletions.
10 changes: 6 additions & 4 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ class Driver(ABC):
"""Adapter for the HPC cluster."""

def __init__(self) -> None:
self.event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None
self._event_queue: Optional[asyncio.Queue[Tuple[int, JobEvent]]] = None

async def ainit(self) -> None:
if self.event_queue is None:
self.event_queue = asyncio.Queue()
@property
def event_queue(self) -> asyncio.Queue[Tuple[int, JobEvent]]:
if self._event_queue is None:
self._event_queue = asyncio.Queue()
return self._event_queue

@abstractmethod
async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
Expand Down
136 changes: 117 additions & 19 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,62 @@

import asyncio
import os
from typing import MutableMapping
from asyncio.subprocess import Process
from typing import (
Any,
Awaitable,
Callable,
Generic,
MutableMapping,
Optional,
Protocol,
TypeVar,
)

from ert.scheduler.driver import Driver, JobEvent

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
_T_contra = TypeVar("_T_contra", contravariant=True)

class LocalDriver(Driver):
def __init__(self) -> None:

class _InitFn(Protocol, Generic[_T_co]):
async def __call__(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> _T_co:
...


_WaitFn = Callable[[_T_contra], Awaitable[bool]]
_KillFn = Callable[[_T_contra], Awaitable[None]]


class TaskDriver(Driver, Generic[_T]):
"""Driver that uses asyncio tasks
Unlike the HPC facing drivers which can `submit`, `kill` and have a
long-running `poll` method, the task driver processes the jobs in a
sequence.
The user of this class need only provide an `init` method, which returns
some type `T` (eg. `asyncio.subprocess.Process` for `LocalDriver`). The
returned object is passed onto the `wait` and `kill` functions.
The split between these 3 functions exists because the `TaskDriver` sends
the appropriate `JobEvent`s during execution.
"""

def __init__(self, init: _InitFn[_T], wait: _WaitFn[_T], kill: _KillFn[_T]) -> None:
super().__init__()
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
self.__init = init
self.__wait = wait
self.__kill = kill

async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
self._tasks[iens] = asyncio.create_task(
self._wait_until_finish(iens, executable, *args, cwd=cwd)
self._run(iens, executable, *args, cwd=cwd)
)

async def kill(self, iens: int) -> None:
Expand All @@ -23,26 +66,81 @@ async def kill(self, iens: int) -> None:
except KeyError:
return

async def _wait_until_finish(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
proc = await asyncio.create_subprocess_exec(
executable,
*args,
cwd=cwd,
preexec_fn=os.setpgrp,
)

if self.event_queue is None:
await self.ainit()
assert self.event_queue is not None
async def _run(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
try:
obj: _T = await self.__init(
iens,
executable,
*args,
cwd=cwd,
)
except Exception:
await self.event_queue.put((iens, JobEvent.FAILED))
return

await self.event_queue.put((iens, JobEvent.STARTED))
try:
if await proc.wait() == 0:
if await self.__wait(obj):
await self.event_queue.put((iens, JobEvent.COMPLETED))
else:
await self.event_queue.put((iens, JobEvent.FAILED))
except asyncio.CancelledError:
proc.terminate()
await self.__kill(obj)
await self.event_queue.put((iens, JobEvent.ABORTED))


class LocalDriver(TaskDriver[Process]):
def __init__(self) -> None:
super().__init__(LocalDriver._init, LocalDriver._wait, LocalDriver._kill)

@staticmethod
async def _init(iens: int, executable: str, /, *args: str, cwd: str) -> Process:
return await asyncio.create_subprocess_exec(
executable,
*args,
cwd=cwd,
preexec_fn=os.setpgrp,
)

@staticmethod
async def _wait(proc: Process) -> bool:
return await proc.wait() == 0

@staticmethod
async def _kill(proc: Process) -> None:
proc.terminate()
await proc.wait()


class MockDriver(TaskDriver[None]):
"""Driver used for tests
This driver accepts zero or more callbacks without any per-task data,
allowing it to be easily used in tests.
"""

def __init__(
self,
init: Optional[_InitFn[None]] = None,
wait: Optional[Awaitable[Any]] = None,
kill: Optional[Awaitable[None]] = None,
) -> None:
super().__init__(
init or MockDriver._default_init, self._default_wait, self._default_kill
)
self._wait = wait
self._kill = kill

@staticmethod
async def _default_init(*args: Any, **kwargs: Any) -> None:
return

async def _default_wait(self, _: None) -> bool:
if self._wait is not None:
return bool(await self._wait)
return False

async def _default_kill(self, _: None) -> None:
if self._kill is not None:
await self._kill
72 changes: 72 additions & 0 deletions tests/unit_tests/scheduler/test_local_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import asyncio

import pytest

from ert.scheduler.driver import JobEvent
from ert.scheduler.local_driver import LocalDriver


async def test_success(tmp_path):
driver = LocalDriver()

# LocalDriver doesn't have a polling task
assert driver.create_poll_task() is None

await driver.submit(42, "/usr/bin/env", "touch", "testfile", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
assert await driver.event_queue.get() == (42, JobEvent.COMPLETED)

assert (tmp_path / "testfile").exists()


async def test_failure(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", "false", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
assert await driver.event_queue.get() == (42, JobEvent.FAILED)


async def test_file_not_found(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/file/not/found", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.FAILED)


async def test_kill(tmp_path):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", "sleep", "10", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await driver.kill(42)
assert await driver.event_queue.get() == (42, JobEvent.ABORTED)


async def test_kill_unresponsive_process(tmp_path):
(tmp_path / "script").write_text(
"""\
trap "" 1 2 3 4 5 6 7 8 10 11 12 13 14 15
sleep 10
"""
)

driver = LocalDriver()

await driver.submit(42, "/bin/sh", tmp_path / "script", cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await driver.kill(42)
assert await driver.event_queue.get() == (42, JobEvent.ABORTED)


@pytest.mark.parametrize(
"cmd,event", [("true", JobEvent.COMPLETED), ("false", JobEvent.FAILED)]
)
async def test_kill_when_job_completed(tmp_path, cmd, event):
driver = LocalDriver()

await driver.submit(42, "/usr/bin/env", cmd, cwd=tmp_path)
assert await driver.event_queue.get() == (42, JobEvent.STARTED)
await asyncio.sleep(0.5)
await driver.kill(42)
assert await driver.event_queue.get() == (42, event)
78 changes: 29 additions & 49 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,41 +7,11 @@

import pytest

from ert.config.forward_model import ForwardModel
from ert.ensemble_evaluator._builder._realization import Realization
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.run_arg import RunArg
from ert.scheduler import scheduler


def create_bash_step(script: str) -> ForwardModel:
return ForwardModel(
name="bash_step",
executable="/usr/bin/env",
arglist=["bash", "-c", script],
)


def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None:
jobs = {
"global_environment": {},
"config_path": "/dev/null",
"config_file": "/dev/null",
"jobList": [
{
"name": step.name,
"executable": step.executable,
"argList": step.arglist,
}
for step in steps
],
"run_id": "0",
"ert_pid": "0",
"real_id": "0",
}

with open(path / "jobs.json", "w") as f:
json.dump(jobs, f)
from ert.scheduler.local_driver import MockDriver


@pytest.fixture
Expand Down Expand Up @@ -73,38 +43,48 @@ async def test_empty():
assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED


async def test_single_job(tmp_path: Path, realization):
step = create_bash_step("echo 'Hello, world!' > testfile")
realization.forward_models = [step]
async def test_single_job(realization):
future = asyncio.Future()

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=lambda _: None)
async def init(iens, exec, *args, cwd):
future.set_result(iens)

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
driver = MockDriver(init=init)

sch = scheduler.Scheduler(driver)
sch.add_realization(realization, callback_timeout=lambda _: None)

assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert (tmp_path / "testfile").read_text() == "Hello, world!\n"
assert await future == realization.iens


async def test_cancel(tmp_path: Path, realization):
step = create_bash_step("touch a; sleep 10; touch b")
realization.forward_models = [step]
async def test_cancel(realization):
pre = asyncio.Event()
post = asyncio.Event()
killed = False

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=lambda _: None)
async def wait():
pre.set()
await asyncio.sleep(10)
post.set()

async def kill():
nonlocal killed
killed = True

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
driver = MockDriver(wait=wait(), kill=kill())
sch = scheduler.Scheduler(driver)
sch.add_realization(realization, callback_timeout=lambda _: None)

scheduler_task = asyncio.create_task(sch.execute())

# Wait for the job to start
await asyncio.sleep(1)
await asyncio.wait_for(pre.wait(), timeout=1)

# Kill all jobs and wait for the scheduler to complete
sch.kill_all_jobs()
await scheduler_task

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()
assert pre.is_set()
assert not post.is_set()
assert killed

0 comments on commit 159b0c7

Please sign in to comment.