diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 05dcb38ecb1..2f7c9cbbdcf 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partialmethod -from typing import Any, Protocol +from typing import Any from _ert.events import ( Event, @@ -104,7 +104,7 @@ class LegacyEnsemble: id_: str def __post_init__(self) -> None: - self._scheduler: _KillAllJobs | None = None + self._scheduler: Scheduler | None = None self._config: EvaluatorServerConfig | None = None self.snapshot: EnsembleSnapshot = self._create_snapshot() self.status = self.snapshot.status @@ -309,16 +309,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches def cancellable(self) -> bool: return True - def cancel(self) -> None: + async def cancel(self) -> None: if self._scheduler is not None: - self._scheduler.kill_all_jobs() + await self._scheduler.kill_all_jobs() logger.debug("evaluator cancelled") -class _KillAllJobs(Protocol): - def kill_all_jobs(self) -> None: ... - - @dataclass class Realization: iens: int diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 3696660524f..c132abca5ca 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -50,8 +50,6 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._config: EvaluatorServerConfig = config self._ensemble: Ensemble = ensemble - self._loop: asyncio.AbstractEventLoop | None = None - self._events: asyncio.Queue[Event] = asyncio.Queue() self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() @@ -190,7 +188,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: if len(events) == 0: events = [EnsembleFailed(ensemble=self.ensemble.id_)] await self._append_message(self.ensemble.update_snapshot(events)) - self._signal_cancel() # let ensemble know it should stop + await self._signal_cancel() # let ensemble know it should stop @property def ensemble(self) -> Ensemble: @@ -216,7 +214,7 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None: event = event_from_json(frame.decode("utf-8")) if type(event) is EEUserCancel: logger.debug("Client asked to cancel.") - self._signal_cancel() + await self._signal_cancel() elif type(event) is EEUserDone: logger.debug("Client signalled done.") self.stop() @@ -322,7 +320,7 @@ async def _server(self) -> None: def stop(self) -> None: self._server_done.set() - def _signal_cancel(self) -> None: + async def _signal_cancel(self) -> None: """ This is just a wrapper around logic for whether to signal cancel via a cancellable ensemble or to use internal stop-mechanism directly @@ -333,8 +331,7 @@ def _signal_cancel(self) -> None: """ if self._ensemble.cancellable: logger.debug("Cancelling current ensemble") - assert self._loop is not None - self._loop.run_in_executor(None, self._ensemble.cancel) + await self._ensemble.cancel() else: logger.debug("Stopping current ensemble") self.stop() @@ -342,7 +339,6 @@ def _signal_cancel(self) -> None: async def _start_running(self) -> None: if not self._config: raise ValueError("no config for evaluator") - self._loop = asyncio.get_running_loop() self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), asyncio.create_task( diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index d55a50b9661..8b95af490df 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -47,7 +47,6 @@ async def process_message(self, msg: str) -> None: async def signal_cancel(self) -> None: await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") - cancel_event = EEUserCancel(monitor=self._id) await self.send(event_to_json(cancel_event)) logger.debug(f"monitor-{self._id} asked server to cancel") diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index bd6392647de..5a16736ff8a 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -503,12 +503,9 @@ async def run_monitor( EESnapshotUpdate, }: event = cast(EESnapshot | EESnapshotUpdate, event) - await asyncio.get_running_loop().run_in_executor( - None, - self.send_snapshot_event, - event, - iteration, - ) + + self.send_snapshot_event(event, iteration) + if event.snapshot.get(STATUS) in { ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 6495ec28052..8ca13446fb0 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -14,7 +14,6 @@ import orjson from pydantic.dataclasses import dataclass -from _ert.async_utils import get_running_loop from _ert.events import Event, ForwardModelStepChecksum, Id, event_from_dict from .driver import Driver @@ -82,7 +81,6 @@ def __init__( real.iens: Job(self, real) for real in (realizations or []) } - self._loop = get_running_loop() self._events: asyncio.Queue[Any] = asyncio.Queue() self._running: asyncio.Event = asyncio.Event() @@ -103,12 +101,8 @@ def __init__( self.checksum: dict[str, dict[str, Any]] = {} - def kill_all_jobs(self) -> None: - assert self._loop - # Checking that the loop is running is required because everest is closing the - # simulation context whenever an optimization simulation batch is done - if self._loop.is_running(): - asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop) + async def kill_all_jobs(self) -> None: + await self.cancel_all_jobs() async def cancel_all_jobs(self) -> None: await self._running.wait() diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index f11d693b15a..59fc4da5405 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -89,6 +89,7 @@ async def test_run_and_cancel_legacy_ensemble( # and the ensemble is set to STOPPED monitor._receiver_timeout = 10.0 cancel = True + await evaluator._ensemble._scheduler._running.wait() async for event in monitor.track(heartbeat_interval=0.1): # Cancel the ensemble upon the arrival of the first event if cancel: