Skip to content

Commit

Permalink
Simplify JobQueue Driver options
Browse files Browse the repository at this point in the history
  • Loading branch information
pinkwah committed Nov 28, 2023
1 parent c784970 commit e2cd681
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 81 deletions.
16 changes: 13 additions & 3 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import threading
import uuid
from functools import partial, partialmethod
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
)

from cloudevents.http.event import CloudEvent

Expand Down Expand Up @@ -183,7 +193,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
# something is long running, the evaluator will know and should send
# commands to the task in order to have it killed/retried.
# See https://github.com/equinor/ert/issues/1229
queue_evaluators = None
queue_evaluators: Optional[Sequence[Callable[[], None]]] = None
if (
self._analysis_config.stop_long_running
and self._analysis_config.minimum_required_realizations > 0
Expand All @@ -208,7 +218,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches

result: str = await self._job_queue.execute(
queue_evaluators,
) # type: ignore
)
print(result)
except Exception as exc:
print(exc)
Expand Down
53 changes: 18 additions & 35 deletions src/ert/job_queue/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Dict,
List,
Optional,
Sequence,
Tuple,
)

Expand All @@ -28,22 +29,9 @@
class Driver(ABC):
def __init__(
self,
options: Optional[List[Tuple[str, str]]] = None,
options: Optional[Sequence[Tuple[str, str]]] = None,
):
self._options: Dict[str, str] = {}

if options:
for key, value in options:
self.set_option(key, value)

def set_option(self, option: str, value: str) -> None:
self._options.update({option: value})

def get_option(self, option_key: str) -> str:
return self._options[option_key]

def has_option(self, option_key: str) -> bool:
return option_key in self._options
self.options: Dict[str, str] = dict(options or [])

@abstractmethod
async def submit(self, realization: "RealizationState") -> None:
Expand Down Expand Up @@ -152,7 +140,7 @@ def __init__(self, queue_options: Optional[List[Tuple[str, str]]]):
self._currently_polling = False

async def run_with_retries(
self, func: Callable[[Any], Awaitable[Any]], error_msg: str = ""
self, func: Callable[[], Awaitable[Any]], error_msg: str = ""
) -> None:
current_attempt = 0
while current_attempt < self._max_attempt:
Expand All @@ -169,12 +157,10 @@ async def run_with_retries(

async def submit(self, realization: "RealizationState") -> None:
submit_cmd = self.build_submit_cmd(
[
"-J",
f"poly_{realization.realization.run_arg.iens}",
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
]
"-J",
f"poly_{realization.realization.run_arg.iens}",
str(realization.realization.job_script),
str(realization.realization.run_arg.runpath),
)
await self.run_with_retries(
lambda: self._submit(submit_cmd, realization=realization),
Expand Down Expand Up @@ -203,17 +189,15 @@ async def _submit(
logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}")
return True

def build_submit_cmd(self, args: List[str]) -> List[str]:
submit_cmd = [
self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub"
]
if self.has_option("LSF_QUEUE"):
submit_cmd += ["-q", self.get_option("LSF_QUEUE")]
def build_submit_cmd(self, *args: str) -> List[str]:
submit_cmd = [self.options.get("BSUB_CMD", "bsub")]
if (lsf_queue := self.options.get("LSF_QUEUE")) is not None:
submit_cmd += ["-q", lsf_queue]

return submit_cmd + args
return [*submit_cmd, *args]

async def run_shell_command(
self, command_to_run: List[str], command_name: str=""
self, command_to_run: List[str], command_name: str = ""
) -> Optional[Tuple[asyncio.subprocess.Process, bytes, bytes]]:
process = await asyncio.create_subprocess_exec(
*command_to_run,
Expand Down Expand Up @@ -243,10 +227,9 @@ async def poll_statuses(self) -> None:
return

poll_cmd = [
str(self.get_option("BJOBS_CMD"))
if self.has_option("BJOBS_CMD")
else "bjobs"
] + list(self._realstate_to_lsfid.values())
self.options.get("BJOBS_CMD", "bjobs"),
*self._realstate_to_lsfid.values(),
]
try:
await self.run_with_retries(lambda: self._poll_statuses(poll_cmd))
# suppress runtime error
Expand Down Expand Up @@ -309,7 +292,7 @@ async def kill(self, realization: "RealizationState") -> None:
lsf_job_id = self._realstate_to_lsfid[realization]
logger.debug(f"Attempting to kill {lsf_job_id=}")
kill_cmd = [
self.get_option("BKILL_CMD") if self.has_option("BKILL_CMD") else "bkill",
self.options.get("BKILL_CMD", "bkill"),
lsf_job_id,
]
await self.run_with_retries(
Expand Down
8 changes: 4 additions & 4 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import ssl
from collections import deque
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Sequence, Union

from cloudevents.conversion import to_json
from cloudevents.http import CloudEvent
Expand Down Expand Up @@ -186,8 +186,8 @@ def _add_realization(self, realization: QueueableRealization) -> int:

def max_running(self) -> int:
max_running = 0
if self.driver.has_option("MAX_RUNNING"):
max_running = int(self.driver.get_option("MAX_RUNNING"))
if (value := self.driver.options.get("MAX_RUNNING")) is not None:
max_running = int(value)
if max_running == 0:
return len(self._realizations)
return max_running
Expand Down Expand Up @@ -308,7 +308,7 @@ async def _realization_statechange_publisher(self) -> None:

async def execute(
self,
evaluators: Optional[List[Callable[..., Any]]] = None,
evaluators: Optional[Sequence[Callable[[], None]]] = None,
) -> str:
if evaluators is None:
evaluators = []
Expand Down
11 changes: 4 additions & 7 deletions src/ert/job_queue/realization_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,10 @@ def on_enter_EXIT(self) -> None:
)
if exit_file_path.exists():
exit_file = etree.parse(exit_file_path)
failed_job = exit_file.find("job").text
error_reason = exit_file.find("reason").text
stderr_capture = exit_file.find("stderr").text

stderr_file = ""
if stderr_file_node := exit_file.find("stderr_file"):
stderr_file = stderr_file_node.text
failed_job = exit_file.findtext("job", default="")
error_reason = exit_file.findtext("reason", default="")
stderr_capture = exit_file.findtext("stderr", default="")
stderr_file = exit_file.findtext("stderr_file", default="")

logger.error(
f"job {failed_job} failed with: '{error_reason}'\n"
Expand Down
6 changes: 3 additions & 3 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from threading import Thread
from time import sleep
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Tuple

import numpy as np

Expand Down Expand Up @@ -61,7 +61,7 @@ def _run_forward_model(
ert.ert_config.preferred_num_cpu,
)

queue_evaluators = None
queue_evaluators: Optional[Sequence[Callable[[], None]]] = None
if (
ert.ert_config.analysis_config.stop_long_running
and ert.ert_config.analysis_config.minimum_required_realizations > 0
Expand All @@ -73,7 +73,7 @@ def _run_forward_model(
)
]

asyncio.run(job_queue.execute(evaluators=queue_evaluators)) # type: ignore
asyncio.run(job_queue.execute(queue_evaluators))

run_context.sim_fs.sync()

Expand Down
31 changes: 2 additions & 29 deletions tests/unit_tests/job_queue/_test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,6 @@
from ert.job_queue import Driver


@pytest.mark.xfail(reason="Needs reimplementation")
def test_set_and_unset_option():
queue_config = QueueConfig(
job_script="script.sh",
queue_system=QueueSystem.LOCAL,
max_submit=2,
queue_options={
QueueSystem.LOCAL: [
("MAX_RUNNING", "50"),
("MAX_RUNNING", ""),
]
},
)
driver = Driver.create_driver(queue_config)
assert driver.get_option("MAX_RUNNING") == "0"
assert driver.set_option("MAX_RUNNING", "42")
assert driver.get_option("MAX_RUNNING") == "42"
driver.set_option("MAX_RUNNING", "")
assert driver.get_option("MAX_RUNNING") == "0"
driver.set_option("MAX_RUNNING", "100")
assert driver.get_option("MAX_RUNNING") == "100"
driver.set_option("MAX_RUNNING", "0")
assert driver.get_option("MAX_RUNNING") == "0"


@pytest.mark.xfail(reason="Needs reimplementation")
def test_get_driver_name():
queue_config = QueueConfig(queue_system=QueueSystem.LOCAL)
Expand Down Expand Up @@ -61,8 +36,6 @@ def test_get_slurm_queue_config():
assert queue_config.queue_system == QueueSystem.SLURM
driver = Driver.create_driver(queue_config)

assert driver.get_option("SBATCH") == "/path/to/sbatch"
assert driver.get_option("SCONTROL") == "scontrol"
driver.set_option("SCONTROL", "")
assert driver.get_option("SCONTROL") == ""
assert driver.options["SBATCH"] == "/path/to/sbatch"
assert driver.options["SCONTROL"] == "scontrol"
assert driver.name == "SLURM"

0 comments on commit e2cd681

Please sign in to comment.