-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Scheduler as alternative to JobQueue
- Loading branch information
Showing
6 changed files
with
369 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import os | ||
import sys | ||
from enum import Enum | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from cloudevents.conversion import to_json | ||
from cloudevents.http import CloudEvent | ||
|
||
from ert.callbacks import forward_model_ok | ||
from ert.job_queue.queue import _queue_state_event_type | ||
from ert.load_status import LoadStatus | ||
|
||
if TYPE_CHECKING: | ||
from ert.ensemble_evaluator._builder._realization import Realization | ||
from ert.scheduler.scheduler import Scheduler | ||
|
||
|
||
class State(str, Enum):
    """Scheduler-internal lifecycle states of a single job.

    Values are their own names; the legacy JobQueue state names are obtained
    via STATE_TO_LEGACY.
    """

    WAITING = "WAITING"
    SUBMITTING = "SUBMITTING"
    STARTING = "STARTING"
    RUNNING = "RUNNING"
    ABORTING = "ABORTING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABORTED = "ABORTED"
|
||
|
||
# Translation from scheduler State values to the legacy JobQueue state names,
# which downstream event consumers (via _queue_state_event_type) still expect.
STATE_TO_LEGACY = {
    State.WAITING: "WAITING",
    State.SUBMITTING: "SUBMITTED",
    State.STARTING: "PENDING",
    State.RUNNING: "RUNNING",
    State.ABORTING: "DO_KILL",
    State.COMPLETED: "SUCCESS",
    State.FAILED: "FAILED",
    State.ABORTED: "IS_KILLED",
}
|
||
|
||
class Job:
    """Handle to a single job scheduler job.

    Instances of this class represent a single job as submitted to a job scheduler
    (LSF, PBS, SLURM, etc.)
    """

    def __init__(self, scheduler: Scheduler, real: Realization) -> None:
        self.real = real
        self._scheduler = scheduler

    @property
    def iens(self) -> int:
        """Realization index of this job within the ensemble."""
        return self.real.iens

    async def __call__(
        self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
    ) -> None:
        """Run the realization's job script as a subprocess.

        Waits for the scheduler-wide `start` event, then runs under the
        concurrency limit imposed by `sem`. State changes are reported as
        events via _send(). Cancellation kills the subprocess.
        """
        await start.wait()
        await sem.acquire()

        proc: Optional[asyncio.subprocess.Process] = None
        try:
            await self._send(State.SUBMITTING)

            proc = await asyncio.create_subprocess_exec(
                sys.executable,
                self.real.job_script,
                cwd=self.real.run_arg.runpath,
                # Detach the child into its own process group so signals sent
                # to our group (e.g. Ctrl-C) do not hit it directly.
                preexec_fn=os.setpgrp,
            )
            await self._send(State.STARTING)
            await self._send(State.RUNNING)
            returncode = await proc.wait()
            # A zero exit code alone is not success: the produced results
            # must also load correctly.
            if (
                returncode == 0
                and forward_model_ok(self.real.run_arg).status
                == LoadStatus.LOAD_SUCCESSFUL
            ):
                await self._send(State.COMPLETED)
            else:
                await self._send(State.FAILED)

        except asyncio.CancelledError:
            await self._send(State.ABORTING)
            if proc:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                # NOTE(review): kill() only signals the direct child, not its
                # process group — descendants spawned by the job script may
                # survive; confirm whether group-wide kill is needed.
                await proc.wait()
            await self._send(State.ABORTED)
        finally:
            sem.release()

    async def _send(self, state: State) -> None:
        """Serialize `state` as a legacy queue CloudEvent and enqueue it."""
        status = STATE_TO_LEGACY[state]
        event = CloudEvent(
            {
                "type": _queue_state_event_type(status),
                "source": f"/etc/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
                "datacontenttype": "application/json",
            },
            {
                "queue_event_type": status,
            },
        )
        await self._scheduler._events.put(to_json(event))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import json | ||
import logging | ||
import os | ||
import ssl | ||
import threading | ||
from asyncio.queues import Queue | ||
from dataclasses import asdict | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
Callable, | ||
Dict, | ||
Iterable, | ||
Optional, | ||
) | ||
|
||
from pydantic.dataclasses import dataclass | ||
from websockets import Headers | ||
from websockets.client import connect | ||
|
||
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED | ||
from ert.scheduler.job import Job | ||
|
||
if TYPE_CHECKING: | ||
from ert.ensemble_evaluator._builder._realization import Realization | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass
class _JobsJson:
    # Connection details merged into each runpath's jobs.json so the forward
    # model dispatcher can report back to the ensemble evaluator.
    ens_id: str
    real_id: str
    dispatch_url: str
    ee_token: Optional[str]
    ee_cert_path: Optional[str]
    experiment_id: str
|
||
|
||
class Scheduler:
    """Async alternative to JobQueue.

    Runs each registered realization as an asyncio task and publishes queue
    state-change events to the ensemble evaluator over a websocket.
    """

    def __init__(self) -> None:
        self._realizations: Dict[int, Job] = {}
        self._tasks: Dict[int, asyncio.Task[None]] = {}
        # Serialized CloudEvents produced by jobs; drained by _publisher().
        self._events: Queue[Any] = Queue()

        self._ee_uri = ""
        self._ens_id = ""
        self._ee_cert: Optional[str] = None
        self._ee_token: Optional[str] = None

    def add_realization(
        self, real: Realization, callback_timeout: Callable[[int], None]
    ) -> None:
        """Register a realization to be run by execute().

        NOTE(review): callback_timeout is accepted for JobQueue interface
        compatibility but is currently unused.
        """
        self._realizations[real.iens] = Job(self, real)

    def kill_all_jobs(self) -> None:
        """Cancel every job task started by execute()."""
        for task in self._tasks.values():
            task.cancel()

    def stop_long_running_jobs(self, minimum_required_realizations: int) -> None:
        # JobQueue interface compatibility; not implemented for Scheduler yet.
        pass

    def set_ee_info(
        self, ee_uri: str, ens_id: str, ee_cert: Optional[str], ee_token: Optional[str]
    ) -> None:
        """Store the ensemble evaluator connection details."""
        self._ee_uri = ee_uri
        self._ens_id = ens_id
        self._ee_cert = ee_cert
        self._ee_token = ee_token

    async def _publisher(self) -> None:
        """Forward queued events to the ensemble evaluator until cancelled."""
        tls: Optional[ssl.SSLContext] = None
        if self._ee_cert:
            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            tls.load_verify_locations(cadata=self._ee_cert)
        headers = Headers()
        if self._ee_token:
            headers["token"] = self._ee_token

        async with connect(
            self._ee_uri,
            ssl=tls,
            extra_headers=headers,
            open_timeout=60,
            ping_timeout=60,
            ping_interval=60,
            close_timeout=60,
        ) as conn:
            while True:
                event = await self._events.get()
                await conn.send(event)

    def add_dispatch_information_to_jobs_file(self) -> None:
        """Amend each realization's jobs.json with evaluator connection info."""
        for job in self._realizations.values():
            self._update_jobs_json(job.iens, job.real.run_arg.runpath)

    async def execute(
        self,
        semaphore: Optional[threading.BoundedSemaphore] = None,
        queue_evaluators: Optional[Iterable[Callable[..., Any]]] = None,
    ) -> str:
        """Run all registered realizations to completion.

        Returns EVTYPE_ENSEMBLE_STOPPED once every job task has finished.
        """
        if queue_evaluators is not None:
            logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}")

        publisher_task = asyncio.create_task(self._publisher())
        try:
            start = asyncio.Event()
            # Mirror the caller's concurrency limit if one is given.
            # _initial_value is private, but threading semaphores expose no
            # public accessor for it.
            sem = asyncio.BoundedSemaphore(
                semaphore._initial_value if semaphore else 10  # type: ignore
            )
            for iens, job in self._realizations.items():
                self._tasks[iens] = asyncio.create_task(job(start, sem))

            start.set()
            for task in self._tasks.values():
                await task
        finally:
            # Always stop the publisher, even when a job task raised or this
            # coroutine was cancelled; otherwise the publisher task leaks.
            publisher_task.cancel()

        return EVTYPE_ENSEMBLE_STOPPED

    def _update_jobs_json(self, iens: int, runpath: str) -> None:
        """Merge dispatch info into the existing jobs.json in `runpath`,
        keeping all other keys intact."""
        jobs = _JobsJson(
            experiment_id="_",
            ens_id=self._ens_id,
            real_id=str(iens),
            dispatch_url=self._ee_uri,
            ee_token=self._ee_token,
            ee_cert_path=self._ee_cert,
        )
        jobs_path = os.path.join(runpath, "jobs.json")
        with open(jobs_path, "r") as fp:
            data = json.load(fp)
        with open(jobs_path, "w") as fp:
            data.update(asdict(jobs))
            json.dump(data, fp)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import asyncio | ||
import json | ||
import shutil | ||
from dataclasses import asdict | ||
from pathlib import Path | ||
from typing import Sequence | ||
|
||
import pytest | ||
|
||
from ert.config.forward_model import ForwardModel | ||
from ert.ensemble_evaluator._builder._realization import Realization | ||
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED | ||
from ert.run_arg import RunArg | ||
from ert.scheduler import scheduler | ||
|
||
|
||
def create_bash_step(script: str) -> ForwardModel:
    """Build a forward model step that runs *script* through bash."""
    arguments = ["bash", "-c", script]
    return ForwardModel(name="bash_step", executable="/usr/bin/env", arglist=arguments)
|
||
|
||
def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None:
    """Write a minimal jobs.json describing *steps* into directory *path*."""
    job_list = []
    for step in steps:
        job_list.append(
            {
                "name": step.name,
                "executable": step.executable,
                "argList": step.arglist,
            }
        )

    jobs = {
        "global_environment": {},
        "config_path": "/dev/null",
        "config_file": "/dev/null",
        "jobList": job_list,
        "run_id": "0",
        "ert_pid": "0",
        "real_id": "0",
    }

    with open(path / "jobs.json", "w") as f:
        json.dump(jobs, f)
|
||
|
||
@pytest.fixture
def realization(storage, tmp_path):
    """A single active realization whose runpath is *tmp_path*."""
    experiment = storage.create_experiment()
    ensemble = experiment.create_ensemble(name="foo", ensemble_size=1)

    run_arg = RunArg(
        run_id="",
        ensemble_storage=ensemble,
        iens=0,
        itr=0,
        runpath=str(tmp_path),
        job_name="",
    )

    dispatcher = shutil.which("job_dispatch.py")
    return Realization(
        iens=0,
        forward_models=[],
        active=True,
        max_runtime=None,
        run_arg=run_arg,
        num_cpu=1,
        job_script=str(dispatcher),
    )
|
||
|
||
async def test_empty():
    """A scheduler with no realizations finishes immediately."""
    empty_scheduler = scheduler.Scheduler()
    result = await empty_scheduler.execute()
    assert result == EVTYPE_ENSEMBLE_STOPPED
|
||
|
||
async def test_single_job(tmp_path: Path, realization):
    """A single echo step runs to completion and leaves its output file."""
    step = create_bash_step("echo 'Hello, world!' > testfile")
    realization.forward_models = [step]

    single_scheduler = scheduler.Scheduler()
    single_scheduler.add_realization(realization, callback_timeout=lambda _: None)

    create_jobs_json(tmp_path, [step])
    single_scheduler.add_dispatch_information_to_jobs_file()

    result = await single_scheduler.execute()
    assert result == EVTYPE_ENSEMBLE_STOPPED

    output = (tmp_path / "testfile").read_text()
    assert output == "Hello, world!\n"
|
||
|
||
async def test_cancel(tmp_path: Path, realization):
    """Cancelling mid-script stops the job: only the first touch runs."""
    step = create_bash_step("touch a; sleep 10; touch b")
    realization.forward_models = [step]

    cancelled_scheduler = scheduler.Scheduler()
    cancelled_scheduler.add_realization(realization, callback_timeout=lambda _: None)

    create_jobs_json(tmp_path, [step])
    cancelled_scheduler.add_dispatch_information_to_jobs_file()

    execution = asyncio.create_task(cancelled_scheduler.execute())

    # Give the bash script time to start and create "a"
    await asyncio.sleep(1)

    # Abort everything, then let the scheduler wind down
    cancelled_scheduler.kill_all_jobs()
    await execution

    assert (tmp_path / "a").exists()
    assert not (tmp_path / "b").exists()