Skip to content

Commit

Permalink
Add LocalDriver for Scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah committed Dec 8, 2023
1 parent 51c2b7b commit cf6a93f
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 20 deletions.
91 changes: 91 additions & 0 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

import asyncio
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
MutableMapping,
Optional,
Tuple,
)


class JobEvent(Enum):
STARTED = 0
COMPLETED = 1
FAILED = 2
ABORTED = 3


class Driver(ABC):
"""Adapter for the HPC cluster."""

event_queue: asyncio.Queue[Tuple[int, JobEvent]]

def __init__(self) -> None:
self.event_queue = asyncio.Queue()

@abstractmethod
async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
"""Submit a program to execute on the cluster.
Args:
iens: Realization number. (Unique for each job)
executable: Program to execute.
args: List of arguments to send to the program.
cwd: Working directory.
"""

@abstractmethod
async def kill(self, iens: int) -> None:
"""Terminate execution of a job associated with a realization.
Args:
iens: Realization number.
"""

def create_poll_task(self) -> Optional[asyncio.Task[None]]:
"""Create a `asyncio.Task` for polling the cluster.
Returns:
`asyncio.Task`, or None if polling is not applicable (eg. for LocalDriver)
"""

return None


class LocalDriver(Driver):
def __init__(self) -> None:
super().__init__()
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
self._tasks[iens] = asyncio.create_task(
self._wait_until_finish(iens, executable, *args, cwd=cwd)
)

async def kill(self, iens: int) -> None:
try:
self._tasks[iens].cancel()
except KeyError:
return

async def _wait_until_finish(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
proc = await asyncio.create_subprocess_exec(
executable,
*args,
cwd=cwd,
preexec_fn=os.setpgrp,
)
await self.event_queue.put((iens, JobEvent.STARTED))
try:
if await proc.wait() == 0:
await self.event_queue.put((iens, JobEvent.COMPLETED))
else:
await self.event_queue.put((iens, JobEvent.FAILED))
except asyncio.CancelledError:
proc.terminate()
await self.event_queue.put((iens, JobEvent.ABORTED))
31 changes: 18 additions & 13 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import annotations

import asyncio
import os
import sys
from enum import Enum
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent

from ert.callbacks import forward_model_ok
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver

if TYPE_CHECKING:
from ert.ensemble_evaluator._builder._realization import Realization
Expand Down Expand Up @@ -50,31 +49,36 @@ class Job:

def __init__(self, scheduler: Scheduler, real: Realization) -> None:
self.real = real
self.started = asyncio.Event()
self.returncode: asyncio.Future[int] = asyncio.Future()
self.aborted = asyncio.Event()
self._scheduler = scheduler

@property
def iens(self) -> int:
return self.real.iens

@property
def driver(self) -> Driver:
return self._scheduler.driver

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
) -> None:
await start.wait()
await sem.acquire()

proc: Optional[asyncio.subprocess.Process] = None
try:
await self._send(State.SUBMITTING)

proc = await asyncio.create_subprocess_exec(
sys.executable,
self.real.job_script,
cwd=self.real.run_arg.runpath,
preexec_fn=os.setpgrp,
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
returncode = await proc.wait()
returncode = await self.returncode
if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
Expand All @@ -86,8 +90,9 @@ async def __call__(

except asyncio.CancelledError:
await self._send(State.ABORTING)
if proc:
proc.kill()
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
sem.release()
Expand Down
35 changes: 28 additions & 7 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
MutableMapping,
Optional,
)

Expand All @@ -22,6 +22,7 @@
from websockets.client import connect

from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.driver import Driver, JobEvent, LocalDriver
from ert.scheduler.job import Job

if TYPE_CHECKING:
Expand All @@ -42,9 +43,12 @@ class _JobsJson:


class Scheduler:
def __init__(self) -> None:
self._realizations: Dict[int, Job] = {}
self._tasks: Dict[int, asyncio.Task[None]] = {}
def __init__(self, driver: Optional[Driver] = None) -> None:
if driver is None:
driver = LocalDriver()
self.driver = driver
self._jobs: MutableMapping[int, Job] = {}
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
self._events: Queue[Any] = Queue()

self._ee_uri = ""
Expand All @@ -55,7 +59,7 @@ def __init__(self) -> None:
def add_realization(
self, real: Realization, callback_timeout: Callable[[int], None]
) -> None:
self._realizations[real.iens] = Job(self, real)
self._jobs[real.iens] = Job(self, real)

def kill_all_jobs(self) -> None:
for task in self._tasks.values():
Expand Down Expand Up @@ -95,7 +99,7 @@ async def _publisher(self) -> None:
await conn.send(event)

def add_dispatch_information_to_jobs_file(self) -> None:
for job in self._realizations.values():
for job in self._jobs.values():
self._update_jobs_json(job.iens, job.real.run_arg.runpath)

async def execute(
Expand All @@ -107,20 +111,37 @@ async def execute(
logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}")

publisher_task = asyncio.create_task(self._publisher())
poller_task = self.driver.create_poll_task()
event_queue_task = asyncio.create_task(self._process_event_queue())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore
for iens, job in self._realizations.items():
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(start, sem))

start.set()
for task in self._tasks.values():
await task

publisher_task.cancel()
event_queue_task.cancel()
if poller_task:
poller_task.cancel()

return EVTYPE_ENSEMBLE_STOPPED

async def _process_event_queue(self) -> None:
while True:
iens, event = await self.driver.event_queue.get()
if event == JobEvent.STARTED:
self._jobs[iens].started.set()
elif event == JobEvent.COMPLETED:
self._jobs[iens].returncode.set_result(0)
elif event == JobEvent.FAILED:
self._jobs[iens].returncode.set_result(1)
elif event == JobEvent.ABORTED:
self._jobs[iens].aborted.set()

def _update_jobs_json(self, iens: int, runpath: str) -> None:
jobs = _JobsJson(
experiment_id="_",
Expand Down

0 comments on commit cf6a93f

Please sign in to comment.