From cf6a93f896d2defe222b853b0dc3376b0a1ded1a Mon Sep 17 00:00:00 2001
From: Zohar Malamant
Date: Thu, 7 Dec 2023 16:16:14 +0100
Subject: [PATCH] Add LocalDriver for Scheduler

---
 src/ert/scheduler/driver.py    | 116 +++++++++++++++++++++++++++++++++
 src/ert/scheduler/job.py       |  31 +++++----
 src/ert/scheduler/scheduler.py |  40 ++++++++++--
 3 files changed, 167 insertions(+), 20 deletions(-)
 create mode 100644 src/ert/scheduler/driver.py

diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py
new file mode 100644
index 00000000000..c30e3be09be
--- /dev/null
+++ b/src/ert/scheduler/driver.py
@@ -0,0 +1,116 @@
+from __future__ import annotations
+
+import asyncio
+import os
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import (
+    MutableMapping,
+    Optional,
+    Tuple,
+)
+
+
+class JobEvent(Enum):
+    """Lifecycle notifications a driver posts on its event queue."""
+
+    STARTED = 0
+    COMPLETED = 1
+    FAILED = 2
+    ABORTED = 3
+
+
+class Driver(ABC):
+    """Adapter for the HPC cluster."""
+
+    event_queue: asyncio.Queue[Tuple[int, JobEvent]]
+
+    def __init__(self) -> None:
+        self.event_queue = asyncio.Queue()
+
+    @abstractmethod
+    async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
+        """Submit a program to execute on the cluster.
+
+        Args:
+          iens: Realization number. (Unique for each job)
+          executable: Program to execute.
+          args: List of arguments to send to the program.
+          cwd: Working directory.
+        """
+
+    @abstractmethod
+    async def kill(self, iens: int) -> None:
+        """Terminate execution of a job associated with a realization.
+
+        Args:
+          iens: Realization number.
+        """
+
+    def create_poll_task(self) -> Optional[asyncio.Task[None]]:
+        """Create an `asyncio.Task` for polling the cluster.
+
+        Returns:
+          `asyncio.Task`, or None if polling is not applicable (e.g. for LocalDriver)
+        """
+
+        return None
+
+
+class LocalDriver(Driver):
+    """Driver that runs each job as a subprocess on the local machine."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
+
+    async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
+        """Schedule the program as a background task and return immediately."""
+        self._tasks[iens] = asyncio.create_task(
+            self._wait_until_finish(iens, executable, *args, cwd=cwd)
+        )
+
+    async def kill(self, iens: int) -> None:
+        """Cancel the task supervising `iens`; unknown realizations are ignored."""
+        try:
+            self._tasks[iens].cancel()
+        except KeyError:
+            return
+
+    async def _wait_until_finish(
+        self, iens: int, executable: str, /, *args: str, cwd: str
+    ) -> None:
+        """Run the program and report its lifecycle on the event queue.
+
+        Queues STARTED once the process exists, then COMPLETED (exit code 0)
+        or FAILED (any other exit code, or the program could not be
+        launched). Queues ABORTED if the task is cancelled.
+        """
+        proc: Optional[asyncio.subprocess.Process] = None
+        try:
+            proc = await asyncio.create_subprocess_exec(
+                executable,
+                *args,
+                cwd=cwd,
+                preexec_fn=os.setpgrp,
+            )
+            await self.event_queue.put((iens, JobEvent.STARTED))
+            if await proc.wait() == 0:
+                await self.event_queue.put((iens, JobEvent.COMPLETED))
+            else:
+                await self.event_queue.put((iens, JobEvent.FAILED))
+        except asyncio.CancelledError:
+            # Cancellation can land before the process exists or after it has
+            # already exited. Neither case may stop ABORTED from being queued:
+            # the cancelled job waits for it.
+            if proc is not None:
+                try:
+                    proc.terminate()
+                except ProcessLookupError:
+                    pass
+            await self.event_queue.put((iens, JobEvent.ABORTED))
+        except OSError:
+            # The program could not be launched (e.g. missing executable or
+            # working directory). Report it; an exception left in this
+            # never-awaited task would leave the job waiting forever.
+            await self.event_queue.put((iens, JobEvent.FAILED))
diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py
index 0cc83943100..e18524b177f 100644
--- a/src/ert/scheduler/job.py
+++ b/src/ert/scheduler/job.py
@@ -1,10 +1,8 @@
 from __future__ import annotations
 
 import asyncio
-import os
-import sys
 from enum import Enum
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 from cloudevents.conversion import to_json
 from cloudevents.http import CloudEvent
@@ -12,6 +10,7 @@
 from ert.callbacks import forward_model_ok
 from ert.job_queue.queue import _queue_state_event_type
 from ert.load_status import LoadStatus
+from ert.scheduler.driver import Driver
 
 if TYPE_CHECKING:
     from ert.ensemble_evaluator._builder._realization import Realization
@@ -50,31 +49,36 @@ class Job:
 
     def __init__(self, scheduler: Scheduler, real: Realization) -> None:
         self.real = real
+        self.started = asyncio.Event()
+        self.returncode: asyncio.Future[int] = asyncio.Future()
+        self.aborted = asyncio.Event()
         self._scheduler = scheduler
 
     @property
     def iens(self) -> int:
         return self.real.iens
 
+    @property
+    def driver(self) -> Driver:
+        return self._scheduler.driver
+
     async def __call__(
         self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
     ) -> None:
         await start.wait()
         await sem.acquire()
-        proc: Optional[asyncio.subprocess.Process] = None
 
         try:
             await self._send(State.SUBMITTING)
-
-            proc = await asyncio.create_subprocess_exec(
-                sys.executable,
-                self.real.job_script,
-                cwd=self.real.run_arg.runpath,
-                preexec_fn=os.setpgrp,
+            await self.driver.submit(
+                self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
             )
+
             await self._send(State.STARTING)
+            await self.started.wait()
+
             await self._send(State.RUNNING)
-            returncode = await proc.wait()
+            returncode = await self.returncode
             if (
                 returncode == 0
                 and forward_model_ok(self.real.run_arg).status
@@ -86,8 +90,9 @@ async def __call__(
 
         except asyncio.CancelledError:
             await self._send(State.ABORTING)
-            if proc:
-                proc.kill()
+            await self.driver.kill(self.iens)
+
+            await self.aborted.wait()
             await self._send(State.ABORTED)
         finally:
             sem.release()
diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py
index e57259178db..defe0156aab 100644
--- a/src/ert/scheduler/scheduler.py
+++ b/src/ert/scheduler/scheduler.py
@@ -12,8 +12,8 @@
     TYPE_CHECKING,
     Any,
     Callable,
-    Dict,
     Iterable,
+    MutableMapping,
     Optional,
 )
 
@@ -22,6 +22,7 @@
 from websockets.client import connect
 
 from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
+from ert.scheduler.driver import Driver, JobEvent, LocalDriver
 from ert.scheduler.job import Job
 
 if TYPE_CHECKING:
@@ -42,9 +43,12 @@ class _JobsJson:
 
 
 class Scheduler:
-    def __init__(self) -> None:
-        self._realizations: Dict[int, Job] = {}
-        self._tasks: Dict[int, asyncio.Task[None]] = {}
+    def __init__(self, driver: Optional[Driver] = None) -> None:
+        if driver is None:
+            driver = LocalDriver()
+        self.driver = driver
+        self._jobs: MutableMapping[int, Job] = {}
+        self._tasks: MutableMapping[int, asyncio.Task[None]] = {}
         self._events: Queue[Any] = Queue()
 
         self._ee_uri = ""
@@ -55,7 +59,7 @@ def __init__(self) -> None:
     def add_realization(
         self, real: Realization, callback_timeout: Callable[[int], None]
     ) -> None:
-        self._realizations[real.iens] = Job(self, real)
+        self._jobs[real.iens] = Job(self, real)
 
     def kill_all_jobs(self) -> None:
         for task in self._tasks.values():
@@ -95,7 +99,7 @@ async def _publisher(self) -> None:
                 await conn.send(event)
 
     def add_dispatch_information_to_jobs_file(self) -> None:
-        for job in self._realizations.values():
+        for job in self._jobs.values():
             self._update_jobs_json(job.iens, job.real.run_arg.runpath)
 
     async def execute(
@@ -107,10 +111,12 @@ async def execute(
             logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}")
 
         publisher_task = asyncio.create_task(self._publisher())
+        poller_task = self.driver.create_poll_task()
+        event_queue_task = asyncio.create_task(self._process_event_queue())
 
         start = asyncio.Event()
         sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10)  # type: ignore
-        for iens, job in self._realizations.items():
+        for iens, job in self._jobs.items():
             self._tasks[iens] = asyncio.create_task(job(start, sem))
 
         start.set()
@@ -118,9 +124,29 @@ async def execute(
             await task
 
         publisher_task.cancel()
+        event_queue_task.cancel()
+        if poller_task:
+            poller_task.cancel()
 
         return EVTYPE_ENSEMBLE_STOPPED
 
+    async def _process_event_queue(self) -> None:
+        """Route driver events to the job of the realization they concern."""
+        while True:
+            iens, event = await self.driver.event_queue.get()
+            job = self._jobs[iens]
+            # Any event implies the job got past submission. A job that could
+            # not be launched reports FAILED without a preceding STARTED, so
+            # release the `started` waiter unconditionally.
+            job.started.set()
+            if event == JobEvent.ABORTED:
+                job.aborted.set()
+            elif event in (JobEvent.COMPLETED, JobEvent.FAILED):
+                # Cancelling a job that awaits `returncode` cancels the future
+                # too; setting a result on a done future would raise.
+                if not job.returncode.done():
+                    job.returncode.set_result(0 if event == JobEvent.COMPLETED else 1)
+
     def _update_jobs_json(self, iens: int, runpath: str) -> None:
         jobs = _JobsJson(
             experiment_id="_",