Skip to content

Commit

Permalink
Emit the timeout event from job.py, not through callback
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Dec 15, 2023
1 parent 7043961 commit 89fa980
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
26 changes: 17 additions & 9 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def setup_timeout_callback(
cloudevent_unary_send: Callable[[CloudEvent], Awaitable[None]],
event_generator: Callable[[str, Optional[int]], CloudEvent],
) -> Tuple[Callable[[int], None], asyncio.Task[None]]:
"""This function is reimplemented inside the Scheduler and should
be removed when Scheduler is the only queue code."""

def on_timeout(iens: int) -> None:
timeout_queue.put_nowait(
event_generator(identifiers.EVTYPE_REALIZATION_TIMEOUT, iens)
Expand Down Expand Up @@ -165,14 +168,15 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
is a function (or bound method) that only takes a CloudEvent as a positional
argument.
"""
# Set up the timeout-mechanism
timeout_queue = asyncio.Queue() # type: ignore
# Based on the experiment id the generator will
# give a function returning cloud event
event_creator = self.generate_event_creator(experiment_id=experiment_id)
on_timeout, send_timeout_future = self.setup_timeout_callback(
timeout_queue, cloudevent_unary_send, event_creator
)
if not isinstance(self._job_queue, Scheduler):
# Set up the timeout-mechanism
timeout_queue = asyncio.Queue() # type: ignore
# Based on the experiment id the generator will
# give a function returning cloud event
event_creator = self.generate_event_creator(experiment_id=experiment_id)
on_timeout, send_timeout_future = self.setup_timeout_callback(
timeout_queue, cloudevent_unary_send, event_creator
)

if not self.id_:
raise ValueError("Ensemble id not set")
Expand All @@ -185,7 +189,11 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
)

for real in self.active_reals:
self._job_queue.add_realization(real, callback_timeout=on_timeout)
if isinstance(self._job_queue, Scheduler):
# Scheduler will always try to publish a timeout event.
self._job_queue.add_realization(real)
else:
self._job_queue.add_realization(real, callback_timeout=on_timeout)

# TODO: this is sort of a callback being preemptively called.
# It should be lifted out of the queue/evaluate, into the evaluator. If
Expand Down
14 changes: 11 additions & 3 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import asyncio
import logging
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Callable, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent

from ert.callbacks import forward_model_ok
from ert.ensemble_evaluator.identifiers import EVTYPE_REALIZATION_TIMEOUT
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver
Expand Down Expand Up @@ -61,7 +63,6 @@ def __init__(
self.returncode: asyncio.Future[int] = asyncio.Future()
self.aborted = asyncio.Event()
self._scheduler = scheduler
self._callback_timeout: Optional[Callable[[int], None]] = callback_timeout

@property
def iens(self) -> int:
Expand Down Expand Up @@ -135,8 +136,15 @@ async def __call__(
async def _max_runtime_task(self) -> None:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
if self._callback_timeout is not None:
self._callback_timeout(self.real.iens)
timeout_event = CloudEvent(
{
"type": EVTYPE_REALIZATION_TIMEOUT,
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
"id": str(uuid.uuid1()),
}
)
await self._scheduler._events.put(to_json(timeout_event))

self.returncode.cancel() # Triggers CancelledError

async def _send(self, state: State) -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Sequence

import pytest
from cloudevents.http import CloudEvent, from_json

from ert.config.forward_model import ForwardModel
from ert.ensemble_evaluator._builder._realization import Realization
Expand Down Expand Up @@ -147,6 +148,13 @@ def mocked_callback(iens: int):
scheduler_task = asyncio.create_task(sch.execute())
await scheduler_task

timeouteventfound = False
while not timeouteventfound and not sch._events.empty():
event = await sch._events.get()
if from_json(event)["type"] == "com.equinor.ert.realization.timeout":
timeouteventfound = True
assert timeouteventfound

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()
assert Path("from_callback").read_text(encoding="utf-8") == str(realization.iens)

0 comments on commit 89fa980

Please sign in to comment.