From 51c2b7b0424cc68f5358cd87e17d4883ff7daec5 Mon Sep 17 00:00:00 2001 From: Zohar Malamant Date: Mon, 4 Dec 2023 12:15:14 +0100 Subject: [PATCH] Add Scheduler as alternative to JobQueue --- .../ensemble_evaluator/_builder/_legacy.py | 11 +- src/ert/scheduler/__init__.py | 0 src/ert/scheduler/job.py | 107 ++++++++++++++ src/ert/scheduler/scheduler.py | 138 ++++++++++++++++++ src/ert/shared/feature_toggling.py | 4 + tests/unit_tests/scheduler/test_scheduler.py | 110 ++++++++++++++ 6 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 src/ert/scheduler/__init__.py create mode 100644 src/ert/scheduler/job.py create mode 100644 src/ert/scheduler/scheduler.py create mode 100644 tests/unit_tests/scheduler/test_scheduler.py diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index e8e83c63b8b..5ccdc03ac50 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -10,8 +10,11 @@ from cloudevents.http.event import CloudEvent from ert.async_utils import get_event_loop +from ert.config.parsing.queue_system import QueueSystem from ert.ensemble_evaluator import identifiers from ert.job_queue import JobQueue +from ert.scheduler.scheduler import Scheduler +from ert.shared.feature_toggling import FeatureToggling from .._wait_for_evaluator import wait_for_evaluator from ._ensemble import Ensemble @@ -41,7 +44,13 @@ def __init__( super().__init__(reals, metadata, id_) if not queue_config: raise ValueError(f"{self} needs queue_config") - self._job_queue = JobQueue(queue_config) + + if FeatureToggling.is_enabled("scheduler"): + if queue_config.queue_system != QueueSystem.LOCAL: + raise NotImplementedError() + self._job_queue = Scheduler() + else: + self._job_queue = JobQueue(queue_config) self.stop_long_running = stop_long_running self.min_required_realizations = min_required_realizations self._config: Optional[EvaluatorServerConfig] = None diff --git a/src/ert/scheduler/__init__.py b/src/ert/scheduler/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py new file mode 100644 index 00000000000..0cc83943100 --- /dev/null +++ b/src/ert/scheduler/job.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import asyncio +import os +import sys +from enum import Enum +from typing import TYPE_CHECKING, Optional + +from cloudevents.conversion import to_json +from cloudevents.http import CloudEvent + +from ert.callbacks import forward_model_ok +from ert.job_queue.queue import _queue_state_event_type +from ert.load_status import LoadStatus + +if TYPE_CHECKING: + from ert.ensemble_evaluator._builder._realization import Realization + from ert.scheduler.scheduler import Scheduler + + +class State(str, Enum): + WAITING = "WAITING" + SUBMITTING = "SUBMITTING" + STARTING = "STARTING" + RUNNING = "RUNNING" + ABORTING = "ABORTING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + ABORTED = "ABORTED" + + +STATE_TO_LEGACY = { + State.WAITING: "WAITING", + State.SUBMITTING: "SUBMITTED", + State.STARTING: "PENDING", + State.RUNNING: "RUNNING", + State.ABORTING: "DO_KILL", + State.COMPLETED: "SUCCESS", + State.FAILED: "FAILED", + State.ABORTED: "IS_KILLED", +} + + +class Job: + """Handle to a single job scheduler job. + + Instances of this class represent a single job as submitted to a job scheduler + (LSF, PBS, SLURM, etc.) 
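    The job does not run on construction; Scheduler.execute() schedules it as an
    asyncio task, gated by a start event and a concurrency semaphore. Once running,
    the realization's job script is executed as a local subprocess and every state
    change is reported as a legacy-compatible event on the scheduler's event queue.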
+ """ + + def __init__(self, scheduler: Scheduler, real: Realization) -> None: + self.real = real + self._scheduler = scheduler + + @property + def iens(self) -> int: + return self.real.iens + + async def __call__( + self, start: asyncio.Event, sem: asyncio.BoundedSemaphore + ) -> None: + await start.wait() + await sem.acquire() + + proc: Optional[asyncio.subprocess.Process] = None + try: + await self._send(State.SUBMITTING) + + proc = await asyncio.create_subprocess_exec( + sys.executable, + self.real.job_script, + cwd=self.real.run_arg.runpath, + preexec_fn=os.setpgrp, + ) + await self._send(State.STARTING) + await self._send(State.RUNNING) + returncode = await proc.wait() + if ( + returncode == 0 + and forward_model_ok(self.real.run_arg).status + == LoadStatus.LOAD_SUCCESSFUL + ): + await self._send(State.COMPLETED) + else: + await self._send(State.FAILED) + + except asyncio.CancelledError: + await self._send(State.ABORTING) + if proc: + proc.kill() + await self._send(State.ABORTED) + finally: + sem.release() + + async def _send(self, state: State) -> None: + status = STATE_TO_LEGACY[state] + event = CloudEvent( + { + "type": _queue_state_event_type(status), + "source": f"/etc/ensemble/{self._scheduler._ens_id}/real/{self.iens}", + "datacontenttype": "application/json", + }, + { + "queue_event_type": status, + }, + ) + await self._scheduler._events.put(to_json(event)) diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py new file mode 100644 index 00000000000..e57259178db --- /dev/null +++ b/src/ert/scheduler/scheduler.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import asyncio +import json +import logging +import os +import ssl +import threading +from asyncio.queues import Queue +from dataclasses import asdict +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Optional, +) + +from pydantic.dataclasses import dataclass +from websockets import Headers +from websockets.client import connect + +from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED +from ert.scheduler.job import Job + +if TYPE_CHECKING: + from ert.ensemble_evaluator._builder._realization import Realization + + +logger = logging.getLogger(__name__) + + +@dataclass +class _JobsJson: + ens_id: str + real_id: str + dispatch_url: str + ee_token: Optional[str] + ee_cert_path: Optional[str] + experiment_id: str + + +class Scheduler: + def __init__(self) -> None: + self._realizations: Dict[int, Job] = {} + self._tasks: Dict[int, asyncio.Task[None]] = {} + self._events: Queue[Any] = Queue() + + self._ee_uri = "" + self._ens_id = "" + self._ee_cert: Optional[str] = None + self._ee_token: Optional[str] = None + + def add_realization( + self, real: Realization, callback_timeout: Callable[[int], None] + ) -> None: + self._realizations[real.iens] = Job(self, real) + + def kill_all_jobs(self) -> None: + for task in self._tasks.values(): + task.cancel() + + def stop_long_running_jobs(self, minimum_required_realizations: int) -> None: + pass + + def set_ee_info( + self, ee_uri: str, ens_id: str, ee_cert: Optional[str], ee_token: Optional[str] + ) -> None: + self._ee_uri = ee_uri + self._ens_id = ens_id + self._ee_cert = ee_cert + self._ee_token = ee_token + + async def _publisher(self) -> None: + tls: Optional[ssl.SSLContext] = None + if self._ee_cert: + tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + tls.load_verify_locations(cadata=self._ee_cert) + headers = Headers() + if self._ee_token: + headers["token"] = self._ee_token + + async with connect( + self._ee_uri, + 
ssl=tls, + extra_headers=headers, + open_timeout=60, + ping_timeout=60, + ping_interval=60, + close_timeout=60, + ) as conn: + while True: + event = await self._events.get() + await conn.send(event) + + def add_dispatch_information_to_jobs_file(self) -> None: + for job in self._realizations.values(): + self._update_jobs_json(job.iens, job.real.run_arg.runpath) + + async def execute( + self, + semaphore: Optional[threading.BoundedSemaphore] = None, + queue_evaluators: Optional[Iterable[Callable[..., Any]]] = None, + ) -> str: + if queue_evaluators is not None: + logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}") + + publisher_task = asyncio.create_task(self._publisher()) + + start = asyncio.Event() + sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore + for iens, job in self._realizations.items(): + self._tasks[iens] = asyncio.create_task(job(start, sem)) + + start.set() + for task in self._tasks.values(): + await task + + publisher_task.cancel() + + return EVTYPE_ENSEMBLE_STOPPED + + def _update_jobs_json(self, iens: int, runpath: str) -> None: + jobs = _JobsJson( + experiment_id="_", + ens_id=self._ens_id, + real_id=str(iens), + dispatch_url=self._ee_uri, + ee_token=self._ee_token, + ee_cert_path=self._ee_cert, + ) + jobs_path = os.path.join(runpath, "jobs.json") + with open(jobs_path, "r") as fp: + data = json.load(fp) + with open(jobs_path, "w") as fp: + data.update(asdict(jobs)) + json.dump(data, fp) diff --git a/src/ert/shared/feature_toggling.py b/src/ert/shared/feature_toggling.py index 116d0515b95..fa1724ef3dc 100644 --- a/src/ert/shared/feature_toggling.py +++ b/src/ert/shared/feature_toggling.py @@ -21,6 +21,10 @@ class FeatureToggling: "Thank you for testing our new features." ), ), + "scheduler": _Feature( + default_enabled=False, + msg="Use Scheduler instead of JobQueue", + ), } _conf = deepcopy(_conf_original) diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py new file mode 100644 index 00000000000..d9813b5e3ee --- /dev/null +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -0,0 +1,110 @@ +import asyncio +import json +import shutil +from dataclasses import asdict +from pathlib import Path +from typing import Sequence + +import pytest + +from ert.config.forward_model import ForwardModel +from ert.ensemble_evaluator._builder._realization import Realization +from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED +from ert.run_arg import RunArg +from ert.scheduler import scheduler + + +def create_bash_step(script: str) -> ForwardModel: + return ForwardModel( + name="bash_step", + executable="/usr/bin/env", + arglist=["bash", "-c", script], + ) + + +def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None: + jobs = { + "global_environment": {}, + "config_path": "/dev/null", + "config_file": "/dev/null", + "jobList": [ + { + "name": step.name, + "executable": step.executable, + "argList": step.arglist, + } + for step in steps + ], + "run_id": "0", + "ert_pid": "0", + "real_id": "0", + } + + with open(path / "jobs.json", "w") as f: + json.dump(jobs, f) + + +@pytest.fixture +def realization(storage, tmp_path): + ensemble = storage.create_experiment().create_ensemble(name="foo", ensemble_size=1) + + run_arg = RunArg( + run_id="", + ensemble_storage=ensemble, + iens=0, + itr=0, + runpath=str(tmp_path), + job_name="", + ) + + return Realization( + iens=0, + forward_models=[], + active=True, + max_runtime=None, + run_arg=run_arg, + num_cpu=1, + 
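        # Assumes job_dispatch.py can be resolved on PATH (i.e. ert is installed
        # in the test environment); shutil.which() returns None if it is not found.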
job_script=str(shutil.which("job_dispatch.py")), + ) + + +async def test_empty(): + sch = scheduler.Scheduler() + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED + + +async def test_single_job(tmp_path: Path, realization): + step = create_bash_step("echo 'Hello, world!' > testfile") + realization.forward_models = [step] + + sch = scheduler.Scheduler() + sch.add_realization(realization, callback_timeout=lambda _: None) + + create_jobs_json(tmp_path, [step]) + sch.add_dispatch_information_to_jobs_file() + + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED + assert (tmp_path / "testfile").read_text() == "Hello, world!\n" + + +async def test_cancel(tmp_path: Path, realization): + step = create_bash_step("touch a; sleep 10; touch b") + realization.forward_models = [step] + + sch = scheduler.Scheduler() + sch.add_realization(realization, callback_timeout=lambda _: None) + + create_jobs_json(tmp_path, [step]) + sch.add_dispatch_information_to_jobs_file() + + scheduler_task = asyncio.create_task(sch.execute()) + + # Wait for the job to start + await asyncio.sleep(1) + + # Kill all jobs and wait for the scheduler to complete + sch.kill_all_jobs() + await scheduler_task + + assert (tmp_path / "a").exists() + assert not (tmp_path / "b").exists()
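
For reference, a minimal sketch of how the pieces introduced in this patch compose
from a driver's point of view. The function name run_ensemble and its arguments are
placeholders, not names from this patch; it assumes an ensemble evaluator is already
listening at ee_uri and that a jobs.json already exists in each runpath (as
create_jobs_json does in the tests above).

import asyncio

from ert.scheduler.scheduler import Scheduler


async def run_ensemble(realizations, ee_uri, ens_id, ee_cert, ee_token):
    sch = Scheduler()
    for real in realizations:
        sch.add_realization(real, callback_timeout=lambda iens: None)

    # Point the scheduler at the ensemble evaluator so that Job state changes
    # are forwarded over the websocket opened by Scheduler._publisher().
    sch.set_ee_info(ee_uri, ens_id, ee_cert, ee_token)

    # Patch each runpath's jobs.json with dispatch information (ens_id, real_id,
    # dispatch_url, token/cert) before any realization is started.
    sch.add_dispatch_information_to_jobs_file()

    # Runs every realization, bounded by an internal semaphore, and returns
    # EVTYPE_ENSEMBLE_STOPPED once all job tasks have finished.
    return await sch.execute()


# e.g. asyncio.run(run_ensemble([real], "wss://host:port/dispatch", "0", None, None))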