Skip to content

Commit

Permalink
Make sure event loop is running before running coroutines in executor
Browse files: browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Jan 9, 2025
1 parent 7127528 commit 9934398
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 21 deletions.
8 changes: 2 additions & 6 deletions src/ert/ensemble_evaluator/_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass
from functools import partialmethod
from typing import Any, Protocol
from typing import Any

from _ert.events import (
Event,
Expand Down Expand Up @@ -104,7 +104,7 @@ class LegacyEnsemble:
id_: str

def __post_init__(self) -> None:
self._scheduler: _KillAllJobs | None = None
self._scheduler: Scheduler | None = None
self._config: EvaluatorServerConfig | None = None
self.snapshot: EnsembleSnapshot = self._create_snapshot()
self.status = self.snapshot.status
Expand Down Expand Up @@ -315,10 +315,6 @@ def cancel(self) -> None:
logger.debug("evaluator cancelled")


class _KillAllJobs(Protocol):
def kill_all_jobs(self) -> None: ...


@dataclass
class Realization:
iens: int
Expand Down
9 changes: 4 additions & 5 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None:
if len(events) == 0:
events = [EnsembleFailed(ensemble=self.ensemble.id_)]
await self._append_message(self.ensemble.update_snapshot(events))
self._signal_cancel() # let ensemble know it should stop
await self._signal_cancel() # let ensemble know it should stop

@property
def ensemble(self) -> Ensemble:
Expand All @@ -216,7 +216,7 @@ async def handle_client(self, dealer: bytes, frame: bytes) -> None:
event = event_from_json(frame.decode("utf-8"))
if type(event) is EEUserCancel:
logger.debug("Client asked to cancel.")
self._signal_cancel()
await self._signal_cancel()
elif type(event) is EEUserDone:
logger.debug("Client signalled done.")
self.stop()
Expand Down Expand Up @@ -322,7 +322,7 @@ async def _server(self) -> None:
def stop(self) -> None:
self._server_done.set()

def _signal_cancel(self) -> None:
async def _signal_cancel(self) -> None:
"""
This is just a wrapper around logic for whether to signal cancel via
a cancellable ensemble or to use internal stop-mechanism directly
Expand All @@ -333,8 +333,7 @@ def _signal_cancel(self) -> None:
"""
if self._ensemble.cancellable:
logger.debug("Cancelling current ensemble")
assert self._loop is not None
self._loop.run_in_executor(None, self._ensemble.cancel)
await self._ensemble.cancel()

Check failure on line 336 in src/ert/ensemble_evaluator/evaluator.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

"cancel" of "LegacyEnsemble" does not return a value (it only ever returns None)
else:
logger.debug("Stopping current ensemble")
self.stop()
Expand Down
9 changes: 3 additions & 6 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,12 +503,9 @@ async def run_monitor(
EESnapshotUpdate,
}:
event = cast(EESnapshot | EESnapshotUpdate, event)
await asyncio.get_running_loop().run_in_executor(
None,
self.send_snapshot_event,
event,
iteration,
)

self.send_snapshot_event(event, iteration)

if event.snapshot.get(STATUS) in {
ENSEMBLE_STATE_STOPPED,
ENSEMBLE_STATE_FAILED,
Expand Down
6 changes: 2 additions & 4 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,10 @@ def __init__(

self.checksum: dict[str, dict[str, Any]] = {}

def kill_all_jobs(self) -> None:
assert self._loop
async def kill_all_jobs(self) -> None:
# Checking that the loop is running is required because Everest closes the
# simulation context whenever an optimization simulation batch is done
if self._loop.is_running():
asyncio.run_coroutine_threadsafe(self.cancel_all_jobs(), self._loop)
await self.cancel_all_jobs()

async def cancel_all_jobs(self) -> None:
await self._running.wait()
Expand Down

0 comments on commit 9934398

Please sign in to comment.