Skip to content

Commit

Permalink
Typing src/ert/job_queue
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Nov 21, 2023
1 parent 01d351f commit 4ca8a66
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 46 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,8 @@ ignore_missing_imports = True
[mypy-ruamel]
ignore_missing_imports = True

[mypy-statemachine]
ignore_missing_imports = True

[mypy-ert.callbacks]
ignore_errors = True
37 changes: 17 additions & 20 deletions src/ert/job_queue/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

if TYPE_CHECKING:
from ert.config import QueueConfig
from ert.job_queue import QueueableRealization, RealizationState
from ert.job_queue import RealizationState


class Driver(ABC):
def __init__(
self,
options: Optional[List[Tuple[str, str]]] = None,
):
self._options = {}
self._options: Dict[str, str] = {}

if options:
for key, value in options:
Expand All @@ -32,15 +32,15 @@ def has_option(self, option_key: str) -> bool:
return option_key in self._options

@abstractmethod
async def submit(self, realization: "RealizationState"):
async def submit(self, realization: "RealizationState") -> None:
pass

@abstractmethod
async def poll_statuses(self):
async def poll_statuses(self) -> None:
pass

@abstractmethod
async def kill(self, realization: "RealizationState"):
async def kill(self, realization: "RealizationState") -> None:
pass

@classmethod
Expand All @@ -60,10 +60,10 @@ def __init__(self, queue_config: List[Tuple[str, str]]):
self._currently_polling = False

@property
def optionnames(self):
def optionnames(self) -> List[str]:
return []

async def submit(self, realization: "RealizationState"):
async def submit(self, realization: "RealizationState") -> None:
"""Submit and *actually (a)wait* for the process to finish."""
realization.accept()
try:
Expand Down Expand Up @@ -93,16 +93,16 @@ async def submit(self, realization: "RealizationState"):
realization.runfail()
# TODO: fetch stdout/stderr

async def poll_statuses(self):
async def poll_statuses(self) -> None:
pass

async def kill(self, realization: "RealizationState"):
async def kill(self, realization: "RealizationState") -> None:
self._processes[realization].kill()
realization.verify_kill()


class LSFDriver(Driver):
def __init__(self, queue_options):
def __init__(self, queue_options: Optional[List[Tuple[str, str]]]):
super().__init__(queue_options)

self._realstate_to_lsfid: Dict["RealizationState", str] = {}
Expand All @@ -113,15 +113,14 @@ def __init__(self, queue_options):

self._currently_polling = False

async def submit(self, realization: "RealizationState"):
submit_cmd = [
async def submit(self, realization: "RealizationState") -> None:
submit_cmd: List[str] = [
"bsub",
"-J",
f"poly_{realization.realization.run_arg.iens}",
realization.realization.job_script,
realization.realization.run_arg.runpath,
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
]
assert shutil.which(submit_cmd[0]) # does not propagate back..
process = await asyncio.create_subprocess_exec(
*submit_cmd,
stdout=asyncio.subprocess.PIPE,
Expand All @@ -142,13 +141,11 @@ async def submit(self, realization: "RealizationState"):
print(f"Submitted job {realization} and got LSF JOBID {lsf_id}")
except Exception:
# We should probably retry the submission, bsub stdout seems flaky.
print(f"ERROR: Could not parse lsf id from: {output}")
print(f"ERROR: Could not parse lsf id from: {output!r}")

async def poll_statuses(self) -> None:
if self._currently_polling:
# Don't repeat if we are called too often.
# So easy in async..
return self._statuses
return
self._currently_polling = True

if not self._realstate_to_lsfid:
Expand Down Expand Up @@ -198,6 +195,6 @@ async def poll_statuses(self) -> None:

self._currently_polling = False

async def kill(self, realization):
async def kill(self, realization: "RealizationState") -> None:
print(f"would like to kill {realization}")
pass
16 changes: 10 additions & 6 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def realization_state(self, iens: int) -> RealizationState:
def count_realization_state(self, state: RealizationState) -> int:
return sum(real.current_state == state for real in self._realizations)

async def run_done_callback(self, state: RealizationState):
async def run_done_callback(self, state: RealizationState) -> Optional[LoadStatus]:
callback_status, status_msg = forward_model_ok(state.realization.run_arg)
if callback_status == LoadStatus.LOAD_SUCCESSFUL:
state.validate()
# todo: implement me
return None

@property
def stopped(self) -> bool:
Expand Down Expand Up @@ -162,7 +164,9 @@ def queue_size(self) -> int:

def _add_realization(self, realization: QueueableRealization) -> int:
self._realizations.append(
RealizationState(self, realization, retries=self._queue_config.max_submit - 1)
RealizationState(
self, realization, retries=self._queue_config.max_submit - 1
)
)
return len(self._realizations) - 1

Expand Down Expand Up @@ -311,9 +315,10 @@ async def execute(
await self.driver.poll_statuses()

for real in self._realizations:
if real.realization.max_runtime is None:
continue
if (
real.realization.max_runtime != None
and real.current_state == RealizationState.RUNNING
real.current_state == RealizationState.RUNNING
and real.start_time
and datetime.datetime.now() - real.start_time
> datetime.timedelta(seconds=real.realization.max_runtime)
Expand Down Expand Up @@ -353,7 +358,6 @@ async def execute(

return EVTYPE_ENSEMBLE_STOPPED


def add_realization_from_run_arg(
self,
run_arg: "RunArg",
Expand Down Expand Up @@ -412,7 +416,7 @@ def stop_long_running_realizations(
sum(real.runtime for real in completed) / finished_realizations
)

for job in self.job_list:
for job in self.job_list: # type: ignore
if job.runtime > LONG_RUNNING_FACTOR * average_runtime:
job.stop()

Expand Down
21 changes: 7 additions & 14 deletions src/ert/job_queue/realization_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,8 @@ class QueueableRealization: # Aka "Job" or previously "JobQueueNode"
max_runtime: Optional[int] = None
callback_timeout: Optional[Callable[[int], None]] = None

def __hash__(self):
# Elevate iens up to two levels? Check if it can be removed from run_arg
return self.run_arg.iens

def __repr__(self):
return str(self.run_arg.iens)


class RealizationState(StateMachine):
class RealizationState(StateMachine): # type: ignore
NOT_ACTIVE = State("NOT ACTIVE")
WAITING = State("WAITING", initial=True)
SUBMITTED = State("SUBMITTED")
Expand Down Expand Up @@ -102,7 +95,7 @@ def __init__(

donotgohere = UNKNOWN.to(STATUS_FAILURE)

def on_enter_state(self, target, event):
def on_enter_state(self, target: RealizationState) -> None:
if self.jobqueue._changes_to_publish is None:
return
if target in (
Expand All @@ -116,21 +109,21 @@ def on_enter_state(self, target, event):
change = {self.realization.run_arg.iens: target.id}
asyncio.create_task(self.jobqueue._changes_to_publish.put(change))

def on_enter_SUBMITTED(self):
def on_enter_SUBMITTED(self) -> None:
asyncio.create_task(self.jobqueue.driver.submit(self))

def on_enter_RUNNING(self):
def on_enter_RUNNING(self) -> None:
self.start_time = datetime.datetime.now()

def on_enter_EXIT(self):
def on_enter_EXIT(self) -> None:
if self.retries_left > 0:
self.retry()
self.retries_left -= 1
else:
self.invalidate()

def on_enter_DONE(self):
def on_enter_DONE(self) -> None:
asyncio.create_task(self.jobqueue.run_done_callback(self))

def on_enter_DO_KILL(self):
def on_enter_DO_KILL(self) -> None:
asyncio.create_task(self.jobqueue.driver.kill(self))
15 changes: 9 additions & 6 deletions tests/unit_tests/job_queue/test_job_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,12 @@ async def test_max_submit(tmpdir, monkeypatch, failing_script, max_submit_num):
await job_queue.stop_jobs_async()
await asyncio.gather(execute_task)


@pytest.mark.asyncio
@pytest.mark.parametrize("max_submit_num", [1, 3])
async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, monkeypatch, simple_script):
async def test_that_kill_queue_disregards_max_submit(
tmpdir, max_submit_num, monkeypatch, simple_script
):
monkeypatch.chdir(tmpdir)
job_queue = create_local_queue(simple_script, max_submit=max_submit_num)
await job_queue.stop_jobs_async()
Expand All @@ -140,7 +143,11 @@ async def test_that_kill_queue_disregards_max_submit(tmpdir, max_submit_num, mon
print(tmpdir)
for iens in range(job_queue.queue_size):
assert not Path(f"dummy_path_{iens}/STATUS").exists()
assert job_queue.count_realization_state(RealizationState.IS_KILLED) == job_queue.queue_size
assert (
job_queue.count_realization_state(RealizationState.IS_KILLED)
== job_queue.queue_size
)


@pytest.mark.asyncio
@pytest.mark.timeout(5)
Expand Down Expand Up @@ -231,9 +238,6 @@ def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script):
assert not (runpath / cert_file).exists()





@pytest.mark.skip(reason="Needs reimplementation")
def test_stop_long_running():
"""
Expand Down Expand Up @@ -283,7 +287,6 @@ def test_stop_long_running():
assert queue.snapshot()[i] == str(JobStatus.RUNNING)



@pytest.mark.usefixtures("use_tmpdir", "mock_fm_ok")
@pytest.mark.skip(reason="Needs reimplementation")
def test_num_cpu_submitted_correctly_lsf(tmpdir, simple_script):
Expand Down

0 comments on commit 4ca8a66

Please sign in to comment.