Skip to content

Commit

Permalink
Support max_runtime in scheduler with callback
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Dec 12, 2023
1 parent bb2af36 commit 56b2ed2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 10 deletions.
34 changes: 27 additions & 7 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
from enum import Enum
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
Expand Down Expand Up @@ -47,12 +47,18 @@ class Job:
(LSF, PBS, SLURM, etc.)
"""

def __init__(self, scheduler: Scheduler, real: Realization) -> None:
def __init__(
self,
scheduler: Scheduler,
real: Realization,
callback_timeout: Optional[Callable[[int], None]] = None,
) -> None:
self.real = real
self.started = asyncio.Event()
self.returncode: asyncio.Future[int] = asyncio.Future()
self.aborted = asyncio.Event()
self._scheduler = scheduler
self._callback_timeout: Optional[Callable[[int], None]] = callback_timeout

@property
def iens(self) -> int:
Expand All @@ -67,6 +73,7 @@ async def __call__(
) -> None:
await start.wait()
await sem.acquire()
cancel_task: Optional[asyncio.Task[None]] = None

try:
await self._send(State.SUBMITTING)
Expand All @@ -78,6 +85,8 @@ async def __call__(
await self.started.wait()

await self._send(State.RUNNING)
if self.real.max_runtime is not None and self.real.max_runtime > 0:
cancel_task = asyncio.create_task(self._cancel_from_timeout())
returncode = await self.returncode
if (
returncode == 0
Expand All @@ -89,14 +98,25 @@ async def __call__(
await self._send(State.FAILED)

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
await self._cancel()
finally:
if cancel_task and not cancel_task.done():
cancel_task.cancel()
sem.release()

async def _cancel_from_timeout(self) -> None:
assert self.real.max_runtime is not None
await asyncio.sleep(self.real.max_runtime)
if self._callback_timeout is not None:
self._callback_timeout(self.real.iens)
self.returncode.cancel() # Triggers CancelledError

async def _cancel(self) -> None:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)
await self.aborted.wait()
await self._send(State.ABORTED)

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
event = CloudEvent(
Expand Down
6 changes: 4 additions & 2 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ async def ainit(self) -> None:
self._events = asyncio.Queue()

def add_realization(
self, real: Realization, callback_timeout: Callable[[int], None]
self,
real: Realization,
callback_timeout: Optional[Callable[[int], None]] = None,
) -> None:
self._jobs[real.iens] = Job(self, real)
self._jobs[real.iens] = Job(self, real, callback_timeout=callback_timeout)

def kill_all_jobs(self) -> None:
for task in self._tasks.values():
Expand Down
24 changes: 23 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def test_cancel(tmp_path: Path, realization):
realization.forward_models = [step]

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=lambda _: None)
sch.add_realization(realization)

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
Expand All @@ -108,3 +108,25 @@ async def test_cancel(tmp_path: Path, realization):

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()


async def test_max_runtime_with_callback_timeout(tmp_path: Path, realization):
step = create_bash_step("touch a; sleep 10; touch b")
realization.forward_models = [step]
realization.max_runtime = 1

def mocked_callback(iens: int):
Path("from_callback").write_text(str(iens), encoding="utf-8")

sch = scheduler.Scheduler()
sch.add_realization(realization, callback_timeout=mocked_callback)

create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()

scheduler_task = asyncio.create_task(sch.execute())
await scheduler_task

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()
assert Path("from_callback").read_text(encoding="utf-8") == str(realization.iens)

0 comments on commit 56b2ed2

Please sign in to comment.