From 99343989501fc553a550369e9c5d98c237d779c6 Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Mon, 6 Jan 2025 14:05:22 +0100 Subject: [PATCH] Make sure event loop is running before running coroutines in executor --- src/ert/ensemble_evaluator/_ensemble.py | 8 ++------ src/ert/ensemble_evaluator/evaluator.py | 9 ++++----- src/ert/run_models/base_run_model.py | 9 +++------ src/ert/scheduler/scheduler.py | 6 ++---- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 05dcb38ecb1..9cf992317ae 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass from functools import partialmethod -from typing import Any, Protocol +from typing import Any from _ert.events import ( Event, @@ -104,7 +104,7 @@ class LegacyEnsemble: id_: str def __post_init__(self) -> None: - self._scheduler: _KillAllJobs | None = None + self._scheduler: Scheduler | None = None self._config: EvaluatorServerConfig | None = None self.snapshot: EnsembleSnapshot = self._create_snapshot() self.status = self.snapshot.status @@ -315,10 +315,6 @@ def cancel(self) -> None: logger.debug("evaluator cancelled") -class _KillAllJobs(Protocol): - def kill_all_jobs(self) -> None: ... - - @dataclass class Realization: iens: int diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 3696660524f..1434ea8960c 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -190,7 +190,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: if len(events) == 0: events = [EnsembleFailed(ensemble=self.ensemble.id_)] await self._append_message(self.ensemble.update_snapshot(events)) - self._signal_cancel() # let ensemble know it should stop + await self._signal_cancel() # let ensemble know it should stop @property def ensemble(self) -> Ensemble: @@ -216,7 +216,7 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None: event = event_from_json(frame.decode("utf-8")) if type(event) is EEUserCancel: logger.debug("Client asked to cancel.") - self._signal_cancel() + await self._signal_cancel() elif type(event) is EEUserDone: logger.debug("Client signalled done.") self.stop() @@ -322,7 +322,7 @@ async def _server(self) -> None: def stop(self) -> None: self._server_done.set() - def _signal_cancel(self) -> None: + async def _signal_cancel(self) -> None: """ This is just a wrapper around logic for whether to signal cancel via a cancellable ensemble or to use internal stop-mechanism directly @@ -333,8 +333,7 @@ def _signal_cancel(self) -> None: """ if self._ensemble.cancellable: logger.debug("Cancelling current ensemble") - assert self._loop is not None - self._loop.run_in_executor(None, self._ensemble.cancel) + await self._ensemble.cancel() else: logger.debug("Stopping current ensemble") self.stop() diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index bd6392647de..5a16736ff8a 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -503,12 +503,9 @@ async def run_monitor( EESnapshotUpdate, }: event = cast(EESnapshot | EESnapshotUpdate, event) - await asyncio.get_running_loop().run_in_executor( - None, - self.send_snapshot_event, - event, - iteration, - ) + + self.send_snapshot_event(event, iteration) + if event.snapshot.get(STATUS) in { ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 6495ec28052..e69686e7f54 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -103,12 +103,10 @@ def __init__( self.checksum: dict[str, dict[str, Any]] = {} - def kill_all_jobs(self) -> None: - assert self._loop + async def kill_all_jobs(self) -> None: # Checking that the loop is running is required because everest is closing the # simulation context whenever an optimization simulation batch is done - if self._loop.is_running(): - asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop) + await self.cancel_all_jobs() async def cancel_all_jobs(self) -> None: await self._running.wait()