From b0b890ccc48772f7da0baace413d0c0f50ec5442 Mon Sep 17 00:00:00 2001 From: Julius Parulek Date: Mon, 11 Dec 2023 14:15:54 +0100 Subject: [PATCH] Add basic retry loop to account for max_submit functionality Use while retry to iterate from running to waiting states. It includes a simple test to check if job has started 3 times. Max_submit is a function parameter of job.__call__ that is passed on from scheduler. --- src/ert/scheduler/driver.py | 13 ++- src/ert/scheduler/job.py | 86 ++++++++++++-------- src/ert/scheduler/local_driver.py | 3 + src/ert/scheduler/scheduler.py | 13 +-- tests/unit_tests/scheduler/test_scheduler.py | 17 +++- 5 files changed, 85 insertions(+), 47 deletions(-) diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 3467c964e5b..2dc710a3f92 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -3,10 +3,7 @@ import asyncio from abc import ABC, abstractmethod from enum import Enum -from typing import ( - Optional, - Tuple, -) +from typing import Optional, Tuple class JobEvent(Enum): @@ -45,6 +42,14 @@ async def kill(self, iens: int) -> None: iens: Realization number. """ + @abstractmethod + async def wait(self, iens: int) -> None: + """Blocks the execution of a job associated with a realization. + + Args: + iens: Realization number. + """ + def create_poll_task(self) -> Optional[asyncio.Task[None]]: """Create a `asyncio.Task` for polling the cluster. diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 022b878f73e..da22418321a 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from enum import Enum from typing import TYPE_CHECKING @@ -16,6 +17,8 @@ from ert.ensemble_evaluator._builder._realization import Realization from ert.scheduler.scheduler import Scheduler +logger = logging.getLogger(__name__) + class State(str, Enum): WAITING = "WAITING" @@ -63,41 +66,58 @@ def driver(self) -> Driver: return self._scheduler.driver async def __call__( - self, start: asyncio.Event, sem: asyncio.BoundedSemaphore + self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2 ) -> None: await start.wait() - await sem.acquire() - - try: - await self._send(State.SUBMITTING) - await self.driver.submit( - self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath - ) - - await self._send(State.STARTING) - await self.started.wait() - - await self._send(State.RUNNING) - while not self.returncode.done(): - await asyncio.sleep(0.01) - returncode = await self.returncode - if ( - returncode == 0 - and forward_model_ok(self.real.run_arg).status - == LoadStatus.LOAD_SUCCESSFUL - ): - await self._send(State.COMPLETED) - else: - await self._send(State.FAILED) - - except asyncio.CancelledError: - await self._send(State.ABORTING) - await self.driver.kill(self.iens) - - await self.aborted.wait() - await self._send(State.ABORTED) - finally: - sem.release() + retries = 0 + retry: bool = True + while retry: + retry = False + await sem.acquire() + try: + await self._send(State.SUBMITTING) + await self.driver.submit( + self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath + ) + + await self._send(State.STARTING) + await self.started.wait() + + await self._send(State.RUNNING) + while not self.returncode.done(): + await asyncio.sleep(0.01) + returncode = await self.returncode + # we need to make sure that the task has finished too + await self.driver.wait(self.real.iens) + + if ( + returncode == 0 + and forward_model_ok(self.real.run_arg).status + == LoadStatus.LOAD_SUCCESSFUL + ): + await self._send(State.COMPLETED) + else: + await self._send(State.FAILED) + retries += 1 + retry = retries < max_submit + if retry: + message = f"Realization: {self.iens} failed, resubmitting" + logger.warning(message) + else: + message = ( + f"Realization: {self.iens} " + f"failed after reaching max submit {max_submit}" + ) + logger.error(message) + + except asyncio.CancelledError: + await self._send(State.ABORTING) + await self.driver.kill(self.iens) + + await self.aborted.wait() + await self._send(State.ABORTED) + finally: + sem.release() async def _send(self, state: State) -> None: status = STATE_TO_LEGACY[state] diff --git a/src/ert/scheduler/local_driver.py b/src/ert/scheduler/local_driver.py index 960d98e9bc4..a9fc6285c90 100644 --- a/src/ert/scheduler/local_driver.py +++ b/src/ert/scheduler/local_driver.py @@ -23,6 +23,9 @@ async def kill(self, iens: int) -> None: except KeyError: return + async def wait(self, iens: int) -> None: + await self._tasks[iens] + async def _wait_until_finish( self, iens: int, executable: str, /, *args: str, cwd: str ) -> None: diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 5b900479e5c..5ec144472f2 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -7,14 +7,7 @@ import ssl import threading from dataclasses import asdict -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - MutableMapping, - Optional, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional from pydantic.dataclasses import dataclass from websockets import Headers @@ -52,6 +45,8 @@ def __init__(self, driver: Optional[Driver] = None) -> None: self._events: Optional[asyncio.Queue[Any]] = None self._cancelled = False + # will be read from QueueConfig + self._max_submit: int = 2 self._ee_uri = "" self._ens_id = "" @@ -131,7 +126,7 @@ async def execute( start = asyncio.Event() sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore for iens, job in self._jobs.items(): - self._tasks[iens] = asyncio.create_task(job(start, sem)) + self._tasks[iens] = asyncio.create_task(job(start, sem, self._max_submit)) start.set() for task in self._tasks.values(): diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 3b071e109a4..7b25158faef 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -1,9 +1,11 @@ import asyncio import json +import os import shutil -from dataclasses import asdict from pathlib import Path +from textwrap import dedent from typing import Sequence +from unittest.mock import patch import pytest @@ -108,3 +110,16 @@ async def test_cancel(tmp_path: Path, realization): assert (tmp_path / "a").exists() assert not (tmp_path / "b").exists() + + +async def test_that_max_submit_was_reached(tmp_path: Path, realization): + script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1" + step = create_bash_step(script) + realization.forward_models = [step] + sch = scheduler.Scheduler() + sch._max_submit = 3 + sch.add_realization(realization, callback_timeout=lambda _: None) + create_jobs_json(tmp_path, [step]) + sch.add_dispatch_information_to_jobs_file() + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED + assert (tmp_path / "cnt").read_text() == "3\n"