Skip to content

Commit

Permalink
Add basic retry loop to account for max_submit functionality
Browse files Browse the repository at this point in the history
Use while retry to iterate from running to waiting states. It includes a
simple test to check if job has started 3 times. Max_submit is a function
parameter of job.__call__ that is passed on from scheduler.
  • Loading branch information
xjules committed Dec 13, 2023
1 parent 225a1e8 commit b0b890c
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 47 deletions.
13 changes: 9 additions & 4 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import asyncio
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Optional,
Tuple,
)
from typing import Optional, Tuple


class JobEvent(Enum):
Expand Down Expand Up @@ -45,6 +42,14 @@ async def kill(self, iens: int) -> None:
iens: Realization number.
"""

@abstractmethod
async def wait(self, iens: int) -> None:
"""Blocks the execution of a job associated with a realization.
Args:
iens: Realization number.
"""

def create_poll_task(self) -> Optional[asyncio.Task[None]]:
"""Create a `asyncio.Task` for polling the cluster.
Expand Down
86 changes: 53 additions & 33 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -16,6 +17,8 @@
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)


class State(str, Enum):
WAITING = "WAITING"
Expand Down Expand Up @@ -63,41 +66,58 @@ def driver(self) -> Driver:
return self._scheduler.driver

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
) -> None:
await start.wait()
await sem.acquire()

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
while not self.returncode.done():
await asyncio.sleep(0.01)
returncode = await self.returncode
if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
sem.release()
retries = 0
retry: bool = True
while retry:
retry = False
await sem.acquire()
try:
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
while not self.returncode.done():
await asyncio.sleep(0.01)
returncode = await self.returncode
# we need to make sure that the task has finished too
await self.driver.wait(self.real.iens)

if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)
retries += 1
retry = retries < max_submit
if retry:
message = f"Realization: {self.iens} failed, resubmitting"
logger.warning(message)
else:
message = (
f"Realization: {self.iens} "
f"failed after reaching max submit {max_submit}"
)
logger.error(message)

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
sem.release()

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
Expand Down
3 changes: 3 additions & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ async def kill(self, iens: int) -> None:
except KeyError:
return

async def wait(self, iens: int) -> None:
await self._tasks[iens]

async def _wait_until_finish(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
Expand Down
13 changes: 4 additions & 9 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import ssl
import threading
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
MutableMapping,
Optional,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional

from pydantic.dataclasses import dataclass
from websockets import Headers
Expand Down Expand Up @@ -52,6 +45,8 @@ def __init__(self, driver: Optional[Driver] = None) -> None:

self._events: Optional[asyncio.Queue[Any]] = None
self._cancelled = False
# will be read from QueueConfig
self._max_submit: int = 2

self._ee_uri = ""
self._ens_id = ""
Expand Down Expand Up @@ -131,7 +126,7 @@ async def execute(
start = asyncio.Event()
sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(start, sem))
self._tasks[iens] = asyncio.create_task(job(start, sem, self._max_submit))

start.set()
for task in self._tasks.values():
Expand Down
17 changes: 16 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import json
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from textwrap import dedent
from typing import Sequence
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -108,3 +110,16 @@ async def test_cancel(tmp_path: Path, realization):

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()


async def test_that_max_submit_was_reached(tmp_path: Path, realization):
script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1"
step = create_bash_step(script)
realization.forward_models = [step]
sch = scheduler.Scheduler()
sch._max_submit = 3
sch.add_realization(realization, callback_timeout=lambda _: None)
create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert (tmp_path / "cnt").read_text() == "3\n"

0 comments on commit b0b890c

Please sign in to comment.