Skip to content

Commit

Permalink
Add Scheduler as alternative to JobQueue
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah committed Dec 7, 2023
1 parent 6c87f5c commit 51c2b7b
Show file tree
Hide file tree
Showing 6 changed files with 369 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from cloudevents.http.event import CloudEvent

from ert.async_utils import get_event_loop
from ert.config.parsing.queue_system import QueueSystem
from ert.ensemble_evaluator import identifiers
from ert.job_queue import JobQueue
from ert.scheduler.scheduler import Scheduler
from ert.shared.feature_toggling import FeatureToggling

from .._wait_for_evaluator import wait_for_evaluator
from ._ensemble import Ensemble
Expand Down Expand Up @@ -41,7 +44,13 @@ def __init__(
super().__init__(reals, metadata, id_)
if not queue_config:
raise ValueError(f"{self} needs queue_config")
self._job_queue = JobQueue(queue_config)

if FeatureToggling.is_enabled("scheduler"):
if queue_config.queue_system != QueueSystem.LOCAL:
raise NotImplementedError()
self._job_queue = Scheduler()
else:
self._job_queue = JobQueue(queue_config)
self.stop_long_running = stop_long_running
self.min_required_realizations = min_required_realizations
self._config: Optional[EvaluatorServerConfig] = None
Expand Down
Empty file added src/ert/scheduler/__init__.py
Empty file.
107 changes: 107 additions & 0 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

import asyncio
import os
import sys
from enum import Enum
from typing import TYPE_CHECKING, Optional

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent

from ert.callbacks import forward_model_ok
from ert.job_queue.queue import _queue_state_event_type
from ert.load_status import LoadStatus

if TYPE_CHECKING:
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler


class State(str, Enum):
    """Lifecycle states of a single scheduler job.

    Mixes in ``str`` so members compare and serialize as their plain
    string values.
    """

    WAITING = "WAITING"
    SUBMITTING = "SUBMITTING"
    STARTING = "STARTING"
    RUNNING = "RUNNING"
    ABORTING = "ABORTING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"
    ABORTED = "ABORTED"


# Translation from Scheduler states to the status strings of the legacy
# JobQueue event protocol (consumed via _queue_state_event_type in Job._send).
# Note the mapping is not 1:1 in spelling (SUBMITTING -> "SUBMITTED", etc.).
STATE_TO_LEGACY = {
    State.WAITING: "WAITING",
    State.SUBMITTING: "SUBMITTED",
    State.STARTING: "PENDING",
    State.RUNNING: "RUNNING",
    State.ABORTING: "DO_KILL",
    State.COMPLETED: "SUCCESS",
    State.FAILED: "FAILED",
    State.ABORTED: "IS_KILLED",
}


class Job:
    """Handle to a single job scheduler job.
    Instances of this class represent a single job as submitted to a job scheduler
    (LSF, PBS, SLURM, etc.)
    """

    def __init__(self, scheduler: Scheduler, real: Realization) -> None:
        # The realization this job runs, and the owning Scheduler, which
        # supplies the outgoing event queue and ensemble id used by _send().
        self.real = real
        self._scheduler = scheduler

    @property
    def iens(self) -> int:
        """Realization index of this job within the ensemble."""
        return self.real.iens

    async def __call__(
        self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
    ) -> None:
        """Execute the realization as a local subprocess.

        Waits for *start* to be set, then acquires *sem* to bound how many
        jobs run concurrently.  Emits a legacy queue state event for every
        state transition.
        """
        await start.wait()
        await sem.acquire()

        proc: Optional[asyncio.subprocess.Process] = None
        try:
            await self._send(State.SUBMITTING)

            # Run the realization's job script with the current interpreter
            # inside the realization's runpath.  setpgrp detaches the child
            # into its own process group.
            proc = await asyncio.create_subprocess_exec(
                sys.executable,
                self.real.job_script,
                cwd=self.real.run_arg.runpath,
                preexec_fn=os.setpgrp,
            )
            await self._send(State.STARTING)
            await self._send(State.RUNNING)
            returncode = await proc.wait()
            # Success requires both a zero exit code and that the produced
            # results load back correctly (forward_model_ok).
            if (
                returncode == 0
                and forward_model_ok(self.real.run_arg).status
                == LoadStatus.LOAD_SUCCESSFUL
            ):
                await self._send(State.COMPLETED)
            else:
                await self._send(State.FAILED)

        except asyncio.CancelledError:
            # Cancellation (via Scheduler.kill_all_jobs) kills the subprocess.
            # NOTE(review): CancelledError is swallowed here, which appears
            # deliberate so that Scheduler.execute()'s `await task` returns
            # normally after a kill — confirm this is intended.
            # NOTE(review): proc.kill() is not followed by `await proc.wait()`,
            # so the child may linger as a zombie until reaped — confirm OK.
            await self._send(State.ABORTING)
            if proc:
                proc.kill()
            await self._send(State.ABORTED)
        finally:
            sem.release()

    async def _send(self, state: State) -> None:
        """Publish a legacy-format CloudEvent for *state* on the scheduler queue."""
        status = STATE_TO_LEGACY[state]
        # NOTE(review): the source path starts with "/etc/ensemble" — this
        # looks like a typo for "/ert/ensemble"; verify against the event
        # consumers before changing.
        event = CloudEvent(
            {
                "type": _queue_state_event_type(status),
                "source": f"/etc/ensemble/{self._scheduler._ens_id}/real/{self.iens}",
                "datacontenttype": "application/json",
            },
            {
                "queue_event_type": status,
            },
        )
        await self._scheduler._events.put(to_json(event))
138 changes: 138 additions & 0 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from __future__ import annotations

import asyncio
import json
import logging
import os
import ssl
import threading
from asyncio.queues import Queue
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
Optional,
)

from pydantic.dataclasses import dataclass
from websockets import Headers
from websockets.client import connect

from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.scheduler.job import Job

if TYPE_CHECKING:
from ert.ensemble_evaluator._builder._realization import Realization


logger = logging.getLogger(__name__)


@dataclass
class _JobsJson:
    """Dispatch information merged into each realization's jobs.json."""

    ens_id: str  # id of the ensemble this realization belongs to
    real_id: str  # realization index, serialized as a string
    dispatch_url: str  # websocket URI of the ensemble evaluator
    ee_token: Optional[str]  # evaluator auth token, if any
    # NOTE(review): despite the name, Scheduler._update_jobs_json passes the
    # certificate *content* (self._ee_cert), not a filesystem path — confirm
    # what the dispatch side expects.
    ee_cert_path: Optional[str]
    experiment_id: str  # placeholder; Scheduler always passes "_"


class Scheduler:
    """Async alternative to the legacy ``JobQueue``.

    Runs each registered realization as a local subprocess (via ``Job``)
    and forwards their state-change CloudEvents to the ensemble evaluator
    over a websocket connection.
    """

    def __init__(self) -> None:
        # Jobs keyed by realization index, and their running asyncio tasks.
        self._realizations: Dict[int, Job] = {}
        self._tasks: Dict[int, asyncio.Task[None]] = {}
        # Outgoing events: produced by Job._send(), drained by _publisher().
        self._events: Queue[Any] = Queue()

        # Ensemble evaluator connection details; filled in by set_ee_info().
        self._ee_uri = ""
        self._ens_id = ""
        self._ee_cert: Optional[str] = None
        self._ee_token: Optional[str] = None

    def add_realization(
        self, real: Realization, callback_timeout: Callable[[int], None]
    ) -> None:
        """Register *real* for execution.

        NOTE(review): *callback_timeout* is currently unused; it seems to
        exist only for interface compatibility with the legacy JobQueue.
        """
        self._realizations[real.iens] = Job(self, real)

    def kill_all_jobs(self) -> None:
        """Cancel every job task (each job kills its own subprocess)."""
        for task in self._tasks.values():
            task.cancel()

    def stop_long_running_jobs(self, minimum_required_realizations: int) -> None:
        # Interface-compatibility stub — not implemented for Scheduler yet.
        pass

    def set_ee_info(
        self, ee_uri: str, ens_id: str, ee_cert: Optional[str], ee_token: Optional[str]
    ) -> None:
        """Store the ensemble evaluator's connection parameters."""
        self._ee_uri = ee_uri
        self._ens_id = ens_id
        self._ee_cert = ee_cert
        self._ee_token = ee_token

    async def _publisher(self) -> None:
        """Forward queued events to the ensemble evaluator until cancelled."""
        tls: Optional[ssl.SSLContext] = None
        if self._ee_cert:
            # self._ee_cert holds the certificate data itself (cadata).
            tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            tls.load_verify_locations(cadata=self._ee_cert)
        headers = Headers()
        if self._ee_token:
            headers["token"] = self._ee_token

        async with connect(
            self._ee_uri,
            ssl=tls,
            extra_headers=headers,
            open_timeout=60,
            ping_timeout=60,
            ping_interval=60,
            close_timeout=60,
        ) as conn:
            while True:
                event = await self._events.get()
                await conn.send(event)

    def add_dispatch_information_to_jobs_file(self) -> None:
        """Inject evaluator connection info into every realization's jobs.json."""
        for job in self._realizations.values():
            self._update_jobs_json(job.iens, job.real.run_arg.runpath)

    async def execute(
        self,
        semaphore: Optional[threading.BoundedSemaphore] = None,
        queue_evaluators: Optional[Iterable[Callable[..., Any]]] = None,
    ) -> str:
        """Run all registered realizations and return EVTYPE_ENSEMBLE_STOPPED.

        *semaphore* only contributes its initial value, which bounds job
        concurrency; *queue_evaluators* is accepted for interface
        compatibility and ignored.
        """
        if queue_evaluators is not None:
            logger.warning(f"Ignoring queue_evaluators: {queue_evaluators}")

        publisher_task = asyncio.create_task(self._publisher())

        # All job tasks are created first, then released simultaneously by
        # setting `start`; `sem` limits how many run at once.
        start = asyncio.Event()
        # NOTE(review): reads the private _initial_value of the threading
        # semaphore; confirm the fallback default of 10 is intended.
        sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10)  # type: ignore
        for iens, job in self._realizations.items():
            self._tasks[iens] = asyncio.create_task(job(start, sem))

        start.set()
        # NOTE(review): tasks are awaited sequentially; if one raises, the
        # remaining tasks keep running and the publisher is never cancelled.
        for task in self._tasks.values():
            await task

        # NOTE(review): the cancelled publisher task is not awaited, so its
        # connection teardown may still be in flight when we return.
        publisher_task.cancel()

        return EVTYPE_ENSEMBLE_STOPPED

    def _update_jobs_json(self, iens: int, runpath: str) -> None:
        # Read the existing jobs.json, merge in the dispatch/connection
        # fields, and write it back so job_dispatch can reach the evaluator.
        jobs = _JobsJson(
            experiment_id="_",
            ens_id=self._ens_id,
            real_id=str(iens),
            dispatch_url=self._ee_uri,
            ee_token=self._ee_token,
            ee_cert_path=self._ee_cert,
        )
        jobs_path = os.path.join(runpath, "jobs.json")
        with open(jobs_path, "r") as fp:
            data = json.load(fp)
        with open(jobs_path, "w") as fp:
            data.update(asdict(jobs))
            json.dump(data, fp)
4 changes: 4 additions & 0 deletions src/ert/shared/feature_toggling.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class FeatureToggling:
"Thank you for testing our new features."
),
),
"scheduler": _Feature(
default_enabled=False,
msg="Use Scheduler instead of JobQueue",
),
}

_conf = deepcopy(_conf_original)
Expand Down
110 changes: 110 additions & 0 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import asyncio
import json
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Sequence

import pytest

from ert.config.forward_model import ForwardModel
from ert.ensemble_evaluator._builder._realization import Realization
from ert.job_queue.queue import EVTYPE_ENSEMBLE_STOPPED
from ert.run_arg import RunArg
from ert.scheduler import scheduler


def create_bash_step(script: str) -> ForwardModel:
    """Build a forward model step that executes *script* through bash."""
    command = ["bash", "-c", script]
    return ForwardModel(
        name="bash_step",
        executable="/usr/bin/env",
        arglist=command,
    )


def create_jobs_json(path: Path, steps: Sequence[ForwardModel]) -> None:
    """Write a minimal jobs.json describing *steps* into directory *path*."""
    job_list = []
    for step in steps:
        job_list.append(
            {
                "name": step.name,
                "executable": step.executable,
                "argList": step.arglist,
            }
        )
    jobs = {
        "global_environment": {},
        "config_path": "/dev/null",
        "config_file": "/dev/null",
        "jobList": job_list,
        "run_id": "0",
        "ert_pid": "0",
        "real_id": "0",
    }

    (path / "jobs.json").write_text(json.dumps(jobs))


@pytest.fixture
def realization(storage, tmp_path):
    """Provide a single active realization backed by a fresh one-member ensemble."""
    experiment = storage.create_experiment()
    ensemble = experiment.create_ensemble(name="foo", ensemble_size=1)

    return Realization(
        iens=0,
        forward_models=[],
        active=True,
        max_runtime=None,
        run_arg=RunArg(
            run_id="",
            ensemble_storage=ensemble,
            iens=0,
            itr=0,
            runpath=str(tmp_path),
            job_name="",
        ),
        num_cpu=1,
        job_script=str(shutil.which("job_dispatch.py")),
    )


async def test_empty():
    """A scheduler with no realizations completes immediately."""
    empty_scheduler = scheduler.Scheduler()
    result = await empty_scheduler.execute()
    assert result == EVTYPE_ENSEMBLE_STOPPED


async def test_single_job(tmp_path: Path, realization):
    """A single bash step runs to completion and produces its output file."""
    step = create_bash_step("echo 'Hello, world!' > testfile")
    realization.forward_models = [step]
    create_jobs_json(tmp_path, [step])

    sch = scheduler.Scheduler()
    sch.add_realization(realization, callback_timeout=lambda _: None)
    sch.add_dispatch_information_to_jobs_file()

    result = await sch.execute()
    assert result == EVTYPE_ENSEMBLE_STOPPED
    assert (tmp_path / "testfile").read_text() == "Hello, world!\n"


async def test_cancel(tmp_path: Path, realization):
    """Killing all jobs aborts the running step before it finishes.

    The step touches ``a``, sleeps, then touches ``b``; a successful kill
    leaves ``a`` present and ``b`` absent.
    """
    step = create_bash_step("touch a; sleep 10; touch b")
    realization.forward_models = [step]

    sch = scheduler.Scheduler()
    sch.add_realization(realization, callback_timeout=lambda _: None)

    create_jobs_json(tmp_path, [step])
    sch.add_dispatch_information_to_jobs_file()

    scheduler_task = asyncio.create_task(sch.execute())

    # Wait until the job has actually started (it touches "a" first) instead
    # of a fixed 1-second sleep, which was both slow and racy on a loaded
    # machine.  Bounded by a timeout so a broken job cannot hang the test.
    async def _wait_for_file(path: Path) -> None:
        while not path.exists():
            await asyncio.sleep(0.01)

    await asyncio.wait_for(_wait_for_file(tmp_path / "a"), timeout=30)

    # Kill all jobs and wait for the scheduler to complete
    sch.kill_all_jobs()
    await scheduler_task

    assert (tmp_path / "a").exists()
    assert not (tmp_path / "b").exists()

0 comments on commit 51c2b7b

Please sign in to comment.