Skip to content

Commit

Permalink
Support max_runtime in Scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Dec 22, 2023
1 parent f00c196 commit b22ce4f
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 12 deletions.
25 changes: 16 additions & 9 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def setup_timeout_callback(
cloudevent_unary_send: Callable[[CloudEvent], Awaitable[None]],
event_generator: Callable[[str, Optional[int]], CloudEvent],
) -> Tuple[Callable[[int], None], asyncio.Task[None]]:
"""This function is reimplemented inside the Scheduler and should
be removed when Scheduler is the only queue code."""

def on_timeout(iens: int) -> None:
timeout_queue.put_nowait(
event_generator(identifiers.EVTYPE_REALIZATION_TIMEOUT, iens)
Expand Down Expand Up @@ -173,14 +176,16 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
is a function (or bound method) that only takes a CloudEvent as a positional
argument.
"""
# Set up the timeout-mechanism
timeout_queue = asyncio.Queue() # type: ignore
# Based on the experiment id the generator will
# give a function returning cloud event
event_creator = self.generate_event_creator(experiment_id=experiment_id)
on_timeout, send_timeout_future = self.setup_timeout_callback(
timeout_queue, cloudevent_unary_send, event_creator
)
timeout_queue: Optional[asyncio.Queue[Any]] = None
if not FeatureToggling.is_enabled("scheduler"):
# Set up the timeout-mechanism
timeout_queue = asyncio.Queue()
# Based on the experiment id the generator will
# give a function returning cloud event
on_timeout, send_timeout_future = self.setup_timeout_callback(
timeout_queue, cloudevent_unary_send, event_creator
)

if not self.id_:
raise ValueError("Ensemble id not set")
Expand Down Expand Up @@ -235,8 +240,10 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
)
result = identifiers.EVTYPE_ENSEMBLE_FAILED

await timeout_queue.put(None) # signal to exit timer
await send_timeout_future
if not isinstance(self._job_queue, Scheduler):
assert timeout_queue is not None
await timeout_queue.put(None) # signal to exit timer
await send_timeout_future

# Dispatch final result from evaluator - FAILED, CANCEL or STOPPED
await cloudevent_unary_send(event_creator(result, None))
Expand Down
26 changes: 24 additions & 2 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import asyncio
import logging
import uuid
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent

from ert.callbacks import forward_model_ok
from ert.ensemble_evaluator.identifiers import EVTYPE_REALIZATION_TIMEOUT
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus
from ert.scheduler.driver import Driver
Expand Down Expand Up @@ -67,6 +69,8 @@ def driver(self) -> Driver:

async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await sem.acquire()
timeout_task: Optional[asyncio.Task[None]] = None

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
Expand All @@ -77,6 +81,8 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await self.started.wait()

await self._send(State.RUNNING)
if self.real.max_runtime is not None and self.real.max_runtime > 0:
timeout_task = asyncio.create_task(self._max_runtime_task())
while not self.returncode.done():
await asyncio.sleep(0.01)
returncode = await self.returncode
Expand All @@ -95,10 +101,11 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
if timeout_task and not timeout_task.done():
timeout_task.cancel()
sem.release()

async def __call__(
Expand All @@ -121,6 +128,21 @@ async def __call__(
)
logger.error(message)

async def _max_runtime_task(self) -> None:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
timeout_event = CloudEvent(
{
"type": EVTYPE_REALIZATION_TIMEOUT,
"source": f"/ert/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
"id": str(uuid.uuid1()),
}
)
assert self._scheduler._events is not None
await self._scheduler._events.put(to_json(timeout_event))

self.returncode.cancel() # Triggers CancelledError

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
event = CloudEvent(
Expand Down
74 changes: 73 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from pathlib import Path

import pytest
from cloudevents.http import CloudEvent, from_json

from ert.ensemble_evaluator._builder._realization import Realization
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.job_queue.queue import EVTYPE_ENSEMBLE_CANCELLED, EVTYPE_ENSEMBLE_STOPPED
from ert.run_arg import RunArg
from ert.scheduler import scheduler

Expand Down Expand Up @@ -177,3 +178,74 @@ async def wait():

assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert retries == max_submit


@pytest.mark.timeout(10)
async def test_max_runtime(realization, mock_driver):
wait_started = asyncio.Event()

async def wait():
wait_started.set()
await asyncio.sleep(100)

realization.max_runtime = 1

sch = scheduler.Scheduler(mock_driver(wait=wait), [realization])

result = await asyncio.create_task(sch.execute())
assert wait_started.is_set()
assert result == EVTYPE_ENSEMBLE_STOPPED

timeouteventfound = False
while not timeouteventfound and not sch._events.empty():
event = await sch._events.get()
if from_json(event)["type"] == "com.equinor.ert.realization.timeout":
timeouteventfound = True
assert timeouteventfound


@pytest.mark.timeout(6)
async def test_max_runtime_while_killing(realization, mock_driver):
wait_started = asyncio.Event()
now_kill_me = asyncio.Event()

async def wait():
# A realization function that lives forever if it was not killed
wait_started.set()
await asyncio.sleep(0.1)
now_kill_me.set()
await asyncio.sleep(1000)

async def kill():
# A kill function that is triggered before the timeout, but finishes
# after MAX_RUNTIME
await asyncio.sleep(1)

realization.max_runtime = 1

sch = scheduler.Scheduler(mock_driver(wait=wait, kill=kill), [realization])

scheduler_task = asyncio.create_task(sch.execute())

await now_kill_me.wait()
sch.kill_all_jobs()

# Sleep until max_runtime must have kicked in:
await asyncio.sleep(1.1)

timeouteventfound = False
while not timeouteventfound and not sch._events.empty():
event = await sch._events.get()
if from_json(event)["type"] == "com.equinor.ert.realization.timeout":
timeouteventfound = True

# Assert that a timeout_event is actually emitted, because killing took a
# long time, and that we should exit normally (asserting no bad things
# happen just because we have two things killing the realization).

assert timeouteventfound
await scheduler_task

# The result from execute is that we were cancelled, not stopped
# as if the timeout happened before kill_all_jobs()
assert scheduler_task.result() == EVTYPE_ENSEMBLE_CANCELLED

0 comments on commit b22ce4f

Please sign in to comment.