Skip to content

Commit

Permalink
Make scheduler cancel() and kill_all_jobs async
Browse files Browse the repository at this point in the history
This commit makes the methods async, so we can await them instead
fire-and-forgetting them through asyncio.run_coroutine_threadsafe.
  • Loading branch information
jonathan-eq committed Jan 10, 2025
1 parent 7127528 commit 4c5af60
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 31 deletions.
12 changes: 4 additions & 8 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from functools import partialmethod
from typing import Any, Protocol
from typing import Any

from _ert.events import (
Event,
Expand Down Expand Up @@ -104,7 +104,7 @@ class LegacyEnsemble:
id_: str

def __post_init__(self) -> None:
self._scheduler: _KillAllJobs | None = None
self._scheduler: Scheduler | None = None
self._config: EvaluatorServerConfig | None = None
self.snapshot: EnsembleSnapshot = self._create_snapshot()
self.status = self.snapshot.status
Expand Down Expand Up @@ -309,16 +309,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
def cancellable(self) -> bool:
return True

def cancel(self) -> None:
async def cancel(self) -> None:
if self._scheduler is not None:
self._scheduler.kill_all_jobs()
await self._scheduler.kill_all_jobs()
logger.debug("evaluator cancelled")


class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None: ...


@dataclass
class Realization:
iens: int
Expand Down
12 changes: 4 additions & 8 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._config: EvaluatorServerConfig = config
self._ensemble: Ensemble = ensemble

self._loop: asyncio.AbstractEventLoop | None = None

self._events: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()
Expand Down Expand Up @@ -190,7 +188,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
if len(events) == 0:
events = [EnsembleFailed(ensemble=self.ensemble.id_)]
await self._append_message(self.ensemble.update_snapshot(events))
self._signal_cancel() # let ensemble know it should stop
await self._signal_cancel() # let ensemble know it should stop

@property
def ensemble(self) -> Ensemble:
Expand All @@ -216,7 +214,7 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:
event = event_from_json(frame.decode("utf-8"))
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
await self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()
Expand Down Expand Up @@ -322,7 +320,7 @@ async def _server(self) -> None:
def stop(self) -> None:
self._server_done.set()

def _signal_cancel(self) -> None:
async def _signal_cancel(self) -> None:
"""
This is just a wrapper around logic for whether to signal cancel via
a cancellable ensemble or to use internal stop-mechanism directly
Expand All @@ -333,16 +331,14 @@ def _signal_cancel(self) -> None:
"""
if self._ensemble.cancellable:
logger.debug("Cancelling current ensemble")
assert self._loop is not None
self._loop.run_in_executor(None, self._ensemble.cancel)
await self._ensemble.cancel()
else:
logger.debug("Stopping current ensemble")
self.stop()

async def _start_running(self) -> None:
if not self._config:
raise ValueError("no config for evaluator")
self._loop = asyncio.get_running_loop()
self._ee_tasks = [
asyncio.create_task(self._server(), name="server_task"),
asyncio.create_task(
Expand Down
1 change: 0 additions & 1 deletion src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ async def process_message(self, msg: str) -> None:
async def signal_cancel(self) -> None:
await self._event_queue.put(Monitor._sentinel)
logger.debug(f"monitor-{self._id} asking server to cancel...")

cancel_event = EEUserCancel(monitor=self._id)
await self.send(event_to_json(cancel_event))
logger.debug(f"monitor-{self._id} asked server to cancel")
Expand Down
9 changes: 3 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,12 +503,9 @@ async def run_monitor(
EESnapshotUpdate,
}:
event = cast(EESnapshot | EESnapshotUpdate, event)
await asyncio.get_running_loop().run_in_executor(
None,
self.send_snapshot_event,
event,
iteration,
)

self.send_snapshot_event(event, iteration)

if event.snapshot.get(STATUS) in {
ENSEMBLE_STATE_STOPPED,
ENSEMBLE_STATE_FAILED,
Expand Down
10 changes: 2 additions & 8 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import orjson
from pydantic.dataclasses import dataclass

from _ert.async_utils import get_running_loop
from _ert.events import Event, ForwardModelStepChecksum, Id, event_from_dict

from .driver import Driver
Expand Down Expand Up @@ -82,7 +81,6 @@ def __init__(
real.iens: Job(self, real) for real in (realizations or [])
}

self._loop = get_running_loop()
self._events: asyncio.Queue[Any] = asyncio.Queue()
self._running: asyncio.Event = asyncio.Event()

Expand All @@ -103,12 +101,8 @@ def __init__(

self.checksum: dict[str, dict[str, Any]] = {}

def kill_all_jobs(self) -> None:
assert self._loop
# Checking that the loop is running is required because everest is closing the
# simulation context whenever an optimization simulation batch is done
if self._loop.is_running():
asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop)
async def kill_all_jobs(self) -> None:
await self.cancel_all_jobs()

async def cancel_all_jobs(self) -> None:
await self._running.wait()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ async def test_run_and_cancel_legacy_ensemble(
# and the ensemble is set to STOPPED
monitor._receiver_timeout = 10.0
cancel = True
await evaluator._ensemble._scheduler._running.wait()
async for event in monitor.track(heartbeat_interval=0.1):
# Cancel the ensemble upon the arrival of the first event
if cancel:
Expand Down

0 comments on commit 4c5af60

Please sign in to comment.