diff --git a/src/ert/job_queue/driver.py b/src/ert/job_queue/driver.py index b2ff5138ec0..63e9068c5f1 100644 --- a/src/ert/job_queue/driver.py +++ b/src/ert/job_queue/driver.py @@ -1,9 +1,18 @@ import asyncio +import logging import os +import re import shlex -import shutil from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, +) from ert.config.parsing.queue_system import QueueSystem @@ -12,6 +21,9 @@ from ert.job_queue import RealizationState +logger = logging.getLogger(__name__) + + class Driver(ABC): def __init__( self, @@ -104,64 +116,144 @@ async def kill(self, realization: "RealizationState") -> None: class LSFDriver(Driver): + LSF_STATUSES = [ + "PEND", + "SSUSP", + "PSUSP", + "USUSP", + "RUN", + "EXIT", + "ZOMBI", + "DONE", + "PDONE", + "UNKWN", + ] + def __init__(self, queue_options: Optional[List[Tuple[str, str]]]): super().__init__(queue_options) self._realstate_to_lsfid: Dict["RealizationState", str] = {} self._lsfid_to_realstate: Dict[str, "RealizationState"] = {} + self._max_attempt = 100 + self._MAX_ERROR_COUNT = 100 + self._statuses = {} self._submit_processes: Dict[ "RealizationState", "asyncio.subprocess.Process" ] = {} - + self._retry_sleep_period = 3 self._currently_polling = False + async def run_with_retries( + self, func: Callable[[None], Awaitable], error_msg: str = "" + ): + current_attempt = 0 + while current_attempt < self._max_attempt: + current_attempt += 1 + try: + function_output = await func() + if function_output: + return function_output + await asyncio.sleep(self._retry_sleep_period) + except asyncio.CancelledError as e: + logger.error(e) + await asyncio.sleep(self._retry_sleep_period) + raise RuntimeError(error_msg) + async def submit(self, realization: "RealizationState") -> None: - submit_cmd: List[str] = [ - "bsub", + submit_cmd = self.parse_submit_cmd( "-J", f"poly_{realization.realization.run_arg.iens}", str(realization.realization.job_script), str(realization.realization.run_arg.runpath), + ) + await self.run_with_retries( + lambda: self._submit(submit_cmd, realization=realization), + error_msg="Maximum number of submit errors exceeded\n", + ) + + async def _submit( + self, submit_command: list[str], realization: "RealizationState" + ) -> bool: + result = await self.run_shell_command(submit_command, command_name="bsub") + if not result: + return False + + (process, output, error) = result + self._submit_processes[realization] = process + lsf_id_match = re.match( + "Job <\\d+> is submitted to \\w+ queue <\\w+>\\.", output.decode() + ) + if lsf_id_match is None: + logger.error(f"Could not parse lsf id from: {output.decode()}") + return False + lsf_id = lsf_id_match.group(0) + self._realstate_to_lsfid[realization] = lsf_id + self._lsfid_to_realstate[lsf_id] = realization + realization.accept() + logger.info(f"Submitted job {realization} and got LSF JOBID {lsf_id}") + return True + + def parse_submit_cmd(self, *additional_parameters) -> List[str]: + submit_cmd = [ + self.get_option("BSUB_CMD") if self.has_option("BSUB_CMD") else "bsub" ] + if self.has_option("LSF_QUEUE"): + submit_cmd += ["-q", self.get_option("LSF_QUEUE")] + + return submit_cmd + list(additional_parameters) + + async def run_shell_command( + self, command_to_run: list[str], command_name="" + ) -> (asyncio.subprocess.Process, str, str): process = await asyncio.create_subprocess_exec( - *submit_cmd, + *command_to_run, 
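# A self-contained sketch (illustrative names, not the driver's API) of the pattern
# run_shell_command wraps: launch the command with asyncio.create_subprocess_exec,
# capture stdout/stderr, and treat a non-zero return code as failure.
import asyncio

async def run_and_check(cmd):
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    output, error = await process.communicate()
    if process.returncode != 0:
        return None
    return output.decode(), error.decode()

# Example usage (assumes a POSIX "echo" on PATH):
# asyncio.run(run_and_check(["echo", "hello"]))  # -> ("hello\n", "")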
stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) - self._submit_processes[realization] = process - - # Wait for submit process to finish: - output, error = await process.communicate() - print(output) # FLAKY ALERT, we seem to get empty - print(error) - - try: - lsf_id = str(output).split(" ")[1].replace("<", "").replace(">", "") - self._realstate_to_lsfid[realization] = lsf_id - self._lsfid_to_realstate[lsf_id] = realization - realization.accept() - print(f"Submitted job {realization} and got LSF JOBID {lsf_id}") - except Exception: - # We should probably retry the submission, bsub stdout seems flaky. - print(f"ERROR: Could not parse lsf id from: {output!r}") + output, _error = await process.communicate() + if process.returncode != 0: + logger.error( + ( + f"{command_name} returned non-zero exitcode: {process.returncode}\n" + f"{output.decode()}\n" + f"{_error.decode()}" + ) + ) + return None + return (process, output, _error) async def poll_statuses(self) -> None: if self._currently_polling: + logger.debug("Already polling status elsewhere") return - self._currently_polling = True if not self._realstate_to_lsfid: # Nothing has been submitted yet. + logger.warning("Skipped polling due to no jobs submitted") return - poll_cmd = ["bjobs"] + list(self._realstate_to_lsfid.values()) - assert shutil.which(poll_cmd[0]) # does not propagate back.. - process = await asyncio.create_subprocess_exec( - *poll_cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - output, _error = await process.communicate() + poll_cmd = [ + str(self.get_option("BJOBS_CMD")) + if self.has_option("BJOBS_CMD") + else "bjobs" + ] + list(self._realstate_to_lsfid.values()) + try: + await self.run_with_retries(lambda: self._poll_statuses(poll_cmd)) + # suppress runtime error + except RuntimeError: + return + except ValueError as e: + # raise this value error as runtime error + raise RuntimeError(e) + + async def _poll_statuses(self, poll_cmd: str) -> bool: + self._currently_polling = True + result = await self.run_shell_command(poll_cmd, command_name="bjobs") + + if result is None: + return False + (_, output, _) = result + for line in output.decode(encoding="utf-8").split("\n"): if "JOBID" in line: continue @@ -172,7 +264,8 @@ async def poll_statuses(self) -> None: continue if tokens[0] not in self._lsfid_to_realstate: # A LSF id we know nothing of, this should not happen. 
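# A standalone sketch of the tokenization this loop performs on bjobs output
# (illustrative helper, not part of the driver): the header row is skipped,
# column 0 holds the LSF job id and column 2 the status.
def parse_bjobs_output(output: str) -> dict:
    statuses = {}
    for bjobs_line in output.splitlines():
        if not bjobs_line.strip() or "JOBID" in bjobs_line:
            continue
        tokens = bjobs_line.split()
        statuses[tokens[0]] = tokens[2]
    return statuses

_example = (
    "JOBID USER STAT QUEUE FROM_HOST EXEC_HOST JOB_NAME SUBMIT_TIME\n"
    "1234 pytest RUN normal mock_host mock_exec_host poly_0 2023-10-01"
)
assert parse_bjobs_output(_example) == {"1234": "RUN"}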
- continue + raise ValueError(f"Found unknown job id ({tokens[0]})") + realstate = self._lsfid_to_realstate[tokens[0]] if tokens[2] == "PEND" and str(realstate.current_state.id) == "WAITING": @@ -194,9 +287,31 @@ async def poll_statuses(self) -> None: realstate.runend() if tokens[2] == "DONE" and str(realstate.current_state.id) == "RUNNING": realstate.runend() + if tokens[2] not in LSFDriver.LSF_STATUSES: + raise ValueError( + f"The lsf_status {tokens[2]} for job {tokens[0]} was not recognized\n" + ) self._currently_polling = False + return True async def kill(self, realization: "RealizationState") -> None: - print(f"would like to kill {realization}") - pass + lsf_job_id = self._realstate_to_lsfid[realization] + logger.debug(f"Attempting to kill {lsf_job_id=}") + kill_cmd = [ + self.get_option("BKILL_CMD") if self.has_option("BKILL_CMD") else "bkill", + lsf_job_id, + ] + await self.run_with_retries( + lambda: self._kill(kill_cmd, realization, lsf_job_id), + error_msg="Maximum number of kill errors exceeded\n", + ) + + async def _kill( + self, kill_cmd, realization: "RealizationState", lsf_job_id: int + ) -> bool: + result = await self.run_shell_command(kill_cmd, "bkill") + if result is None: + return False + realization.verify_kill() + logger.info(f"Successfully killed job {lsf_job_id}") diff --git a/tests/unit_tests/job_queue/test_lsf_driver.py b/tests/unit_tests/job_queue/test_lsf_driver.py index 24019ab14e0..230c616e995 100644 --- a/tests/unit_tests/job_queue/test_lsf_driver.py +++ b/tests/unit_tests/job_queue/test_lsf_driver.py @@ -1,71 +1,110 @@ +import logging import os +import re from argparse import ArgumentParser +from datetime import datetime from pathlib import Path from textwrap import dedent +from typing import Dict, List, Tuple import pytest from ert.__main__ import ert_parser from ert.cli import ENSEMBLE_EXPERIMENT_MODE from ert.cli.main import run_cli +from ert.job_queue.driver import LSFDriver + + +@pytest.fixture +def mock_bsub(tmp_path): + script_path = tmp_path / "mock_bsub" + script_path.write_text( + "#!/usr/bin/env python3" + + dedent( + """ + import sys + import time + import random + run_path = sys.argv[-1] + with open("job_paths", "a+", encoding="utf-8") as jobs_file: + jobs_file.write(f"{run_path}\\n") + + # debug purposes + with open("bsub_log", "a+", encoding="utf-8") as f: + f.write(f"{' '.join(sys.argv)}\\n") + + time.sleep(0.5) + + if "exit.sh" in sys.argv: + exit(1) + + if "gargled_return.sh" in sys.argv: + print("wait,this_is_not_a_valid_return_format") + else: + _id = str(random.randint(0, 10000000)) + print(f"Job <{_id}> is submitted to default queue <normal>.") + """ + ) + ) + os.chmod(script_path, 0o755) + + +@pytest.fixture +def mock_bkill(tmp_path): + script_path = tmp_path / "mock_bkill" + script_path.write_text( + "#!/usr/bin/env python3" + + dedent( + """ + import sys + import time + import random + job_id = sys.argv[-1] + with open("job_ids", "a+", encoding="utf-8") as jobs_file: + jobs_file.write(f"{job_id}\\n") + + time.sleep(0.5) + + if job_id == "non_existent_jobid": + print(f"bkill: jobid {job_id} not found") + exit(1) + """ + ) + ) + os.chmod(script_path, 0o755) @pytest.fixture def mock_bjobs(tmp_path): script = "#!/usr/bin/env python3" + dedent( """ - import datetime - import json - import os.path - import sys - - timestamp = str(datetime.datetime.now()) - - # File written from the mocked bsub command which provides us with - # the path to where the job actually runs and where we can find i.e - # the job_id and status - with 
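# The mock_bsub/mock_bkill fixtures above follow a common pattern: write a small
# python script to tmp_path, mark it executable, and later point the driver at it
# via set_option("BSUB_CMD"/"BKILL_CMD"/"BJOBS_CMD", ...). A minimal sketch of that
# helper (illustrative, not one of the fixtures themselves):
import os
import tempfile
from pathlib import Path

def write_mock_command(path: Path, body: str) -> Path:
    path.write_text("#!/usr/bin/env python3\n" + body)
    os.chmod(path, 0o755)
    return path

with tempfile.TemporaryDirectory() as tmp:
    script = write_mock_command(
        Path(tmp) / "mock_bsub",
        'print("Job <42> is submitted to default queue <normal>.")',
    )
    assert os.access(script, os.X_OK)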
open("job_paths", encoding="utf-8") as job_paths_file: - job_paths = job_paths_file.read().splitlines() - - print("JOBID USER STAT QUEUE FROM_HOST EXEC_HOST JOB_NAME SUBMIT_TIME") - for line in job_paths: - if not os.path.isfile(line + "/lsf_info.json"): - continue - - # ERT has picked up the mocked response from mock_bsub and - # written the id to file - with open(line + "/lsf_info.json") as id_file: - _id = json.load(id_file)["job_id"] - - # Statuses LSF can give us - # "PEND" - # "SSUSP" - # "PSUSP" - # "USUSP" - # "RUN" - # "EXIT" - # "ZOMBI" : does not seem to be available from the api. - # "DONE" - # "PDONE" : Post-processor is done. - # "UNKWN" - status = "RUN" - if os.path.isfile(f"{line}/OK"): - status = "DONE" - - # Together with the headerline this is actually how LSF is - # providing its statuses on the job and how we are picking these - # up. In this mocked version i just check if the job is done - # with the OK file and then print that status for the job_id - # retrieved from the same runpath. - print( - f"{_id} pytest {status} normal" - f" mock_host mock_exec_host poly_0 {timestamp}" - ) + import datetime + import json + import os.path + import sys # Just to have a log for test purposes what is actually thrown # towards the bjobs command - with open("bjobs_log", "a+", encoding="utf-8") as f: - f.write(f"{str(sys.argv)}\\n") + with open("bjobs_log", "a+", encoding="utf-8") as f: + f.write(f"{str(sys.argv)}\\n") + print("JOBID\tUSER\tSTAT\tQUEUE\tFROM_HOST\tEXEC_HOST\tJOB_NAME\tSUBMIT_TIME") + + # Statuses LSF can give us + # "PEND" + # "SSUSP" + # "PSUSP" + # "USUSP" + # "RUN" + # "EXIT" + # "ZOMBI" : does not seem to be available from the api. + # "DONE" + # "PDONE" : Post-processor is done. + # "UNKWN" + with open("mocked_result", mode="r+", encoding="utf-8") as result_line: + result = result_line.read() + if "exit" in result.split("\t"): + exit(1) + print(result) """ ) script_path = tmp_path / "mock_bjobs" @@ -75,75 +114,377 @@ def mock_bjobs(tmp_path): os.chmod(script_path, 0o755) -def make_mock_bsub(script_path): - script_path.write_text( - "#!/usr/bin/env python3" - + dedent( - """ - import random - import subprocess - import sys - - job_dispatch_path = sys.argv[-2] - run_path = sys.argv[-1] - - # Write a file with the runpaths to where the jobs are running and - # writing information we later need when providing statuses for the - # jobs through the mocked bjobs command - with open("job_paths", "a+", encoding="utf-8") as jobs_file: - jobs_file.write(f"{run_path}\\n") - - # Just a log for testpurposes showing what is thrown against the - # bsub command - with open("bsub_log", "a+", encoding="utf-8") as f: - f.write(f"{str(sys.argv)}\\n") - - # Assigning a "unique" job id for each submitted job and print. This - # is how LSF provide response to ERT with the ID of the job. 
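# The rewritten mock_bjobs above echoes whatever single line the test placed in
# "mocked_result". A sketch of how such a tab-separated line lines up with the
# printed header (same column order as create_fake_bjobs_result further down):
from datetime import datetime

def fake_bjobs_line(job_id: str, status: str) -> str:
    return "\t".join(
        [job_id, "pytest", status, "test_queue", "host1", "host2", "test_job", str(datetime.now())]
    )

_line = fake_bjobs_line("1234", "RUN")
assert _line.split("\t")[0] == "1234" and _line.split("\t")[2] == "RUN"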
- _id = str(random.randint(0, 10000000)) - print(f"Job <{_id}> is submitted to default queue <normal>.") - - # Launch job-dispatch - subprocess.Popen([job_dispatch_path, run_path]) - """ - ) +class MockStateHandler: + id = "SUBMITTED" + + +class MockRunArg: + runpath = "/usr/random/ert_path" + + def iens(): + return 0 + + +class MockQueueableRealization: + run_arg = MockRunArg() + + +class MockRealizationState: + _state = "SUBMITTED" + _verified_killed = False + realization = MockQueueableRealization() + current_state = MockStateHandler() + + def verify_kill(self): + self._verified_killed = True + print("Realization was verified killed") + + def accept(self): + self.current_state.id = "PENDING" + + def start(self): + self.current_state.id = "RUNNING" + + def runend(self): + self.current_state.id = "DONE" + + +def create_fake_bjobs_result(dir: str, job_id: str, status: str): + # print("JOBID USER STAT QUEUE FROM_HOST EXEC_HOST JOB_NAME SUBMIT_TIME") + Path(dir / "mocked_result").write_text( + f"{job_id}\tpytest\t{status}\ttest_queue\thost1\thost2\ttest_job\t{str(datetime.now())}" ) - os.chmod(script_path, 0o755) -def make_failing_bsub(script_path, success_script): - """ - Approx 3/10 of the submits will fail due to the random generator in the - created mocked bsub script. By using the retry functionality towards - queue-errors in job_queue.cpp we should still manage to finalize all our runs - before exhausting the limits - """ - script_path.write_text( - "#!/usr/bin/env python3" - + dedent( - f""" - import random - import sys - import subprocess - - num = random.random() - if num > 0.7: - exit(1) - subprocess.call(["python", "{success_script}"] + sys.argv) - """ - ) +@pytest.mark.asyncio +async def test_submit_failure_script_exit(mock_bsub, caplog, tmpdir, monkeypatch): + monkeypatch.chdir(tmpdir) + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BSUB_CMD", tmpdir / "mock_bsub") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + + mock_realization_state = MockRealizationState() + mock_realization_state.realization.job_script = "exit.sh" + + with pytest.raises(RuntimeError, match="Maximum number of submit errors exceeded"): + await lsf_driver.submit(mock_realization_state) + + job_paths = Path("job_paths").read_text(encoding="utf-8").strip().split("\n") + + # should try command 3 times before exiting + assert len(job_paths) == 3 + + output = caplog.text + assert len(re.findall("bsub returned non-zero exitcode: 1", output)) == 3 + + +@pytest.mark.asyncio +async def test_submit_failure_badly_formated_return( + mock_bsub, caplog, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BSUB_CMD", tmpdir / "mock_bsub") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + + mock_realization_state = MockRealizationState() + mock_realization_state.realization.job_script = "gargled_return.sh" + + with pytest.raises(RuntimeError, match="Maximum number of submit errors exceeded"): + await lsf_driver.submit(mock_realization_state) + + job_paths = Path("job_paths").read_text(encoding="utf-8").strip().split("\n") + + # should try command 3 times before exiting + assert len(job_paths) == 3 + + output = caplog.text + print(f"{output=}") + assert len(re.findall("Could not parse lsf id from", output)) == 3 + + +@pytest.mark.asyncio +async def test_submit_success(mock_bsub, caplog, tmpdir, monkeypatch): + caplog.set_level(logging.DEBUG) + monkeypatch.chdir(tmpdir) + lsf_driver = LSFDriver(None) + 
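# A standalone sketch of the submission-message parsing that the two failure tests
# above and test_submit_success exercise. Note that the driver's regex match uses
# group(0), i.e. the whole matched sentence; the capturing group used here is an
# illustrative variant for pulling out just the numeric job id.
import re

def parse_bsub_job_id(output: str):
    match = re.match(r"Job <(\d+)> is submitted to \w+ queue <\w+>\.", output)
    return match.group(1) if match else None

assert parse_bsub_job_id("Job <12345> is submitted to default queue <normal>.") == "12345"
assert parse_bsub_job_id("wait,this_is_not_a_valid_return_format") is None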
lsf_driver.set_option("BSUB_CMD", tmpdir / "mock_bsub") + lsf_driver._max_attempt = 3 + + mock_realization_state = MockRealizationState() + mock_realization_state.realization.job_script = "valid_script.sh" + + await lsf_driver.submit(mock_realization_state) + + job_paths = Path("job_paths").read_text(encoding="utf-8").strip().split("\n") + assert len(job_paths) == 1 + output = caplog.text + assert re.search("Submitted job.*and got LSF JOBID", output) + assert re.search("submitted to default queue", output) + assert mock_realization_state.current_state.id == "PENDING" + + +@pytest.mark.asyncio +async def test_poll_statuses_while_already_polling( + mock_bjobs, caplog, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(None) + lsf_driver._currently_polling = True + lsf_driver._statuses.update({"test_lsf_job_id": "RUNNING"}) + await lsf_driver.poll_statuses() + + # Should never call bjobs + assert not Path("bjobs_logs").exists() + + output = caplog.text + assert output == "" + assert lsf_driver._currently_polling + + +@pytest.mark.asyncio +async def test_poll_statuses_before_submitting_jobs(): + lsf_driver = LSFDriver(None) + + # should not crash + await lsf_driver.poll_statuses() + + +@pytest.mark.asyncio +async def test_poll_statuses_bjobs_exit_code_1(mock_bjobs, caplog, tmpdir, monkeypatch): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BJOBS_CMD", tmpdir / "mock_bjobs") + lsf_driver._max_attempt = 3 + + # will return a job id triggering exit(1) in bjobs + create_fake_bjobs_result(tmpdir, job_id="exit", status="PEND") + + mock_realization_state = MockRealizationState() + lsf_driver._realstate_to_lsfid[mock_realization_state] = "valid_job_id" + lsf_driver._lsfid_to_realstate["valid_job_id"] = mock_realization_state + + # should print out and ignore the unknown job id + await lsf_driver.poll_statuses() + + bjobs_logs = Path("bjobs_log").read_text(encoding="utf-8").strip().split("\n") + + # Should only call bjobs once + assert len(bjobs_logs) == 3 + + output = caplog.text + assert len(re.findall("bjobs returned non-zero exitcode: 1", output)) == 3 + assert ( + mock_realization_state.current_state.id + == MockRealizationState().current_state.id ) - os.chmod(script_path, 0o755) +@pytest.mark.asyncio +async def test_poll_statuses_bjobs_returning_unknown_job_id( + mock_bjobs, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BJOBS_CMD", tmpdir / "mock_bjobs") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + # will return a job id not belonging to this run + create_fake_bjobs_result(tmpdir, job_id="unknown_job_id", status="PEND") + + mock_realization_state = MockRealizationState() + lsf_driver._realstate_to_lsfid[mock_realization_state] = "valid_job_id" + lsf_driver._lsfid_to_realstate["valid_job_id"] = mock_realization_state + + # should print out and ignore the unknown job id + with pytest.raises(RuntimeError, match="Found unknown job id \\(unknown_job_id\\)"): + await lsf_driver.poll_statuses() + + bjobs_logs = Path("bjobs_log").read_text(encoding="utf-8").strip().split("\n") + + # Should only call bjobs once + assert len(bjobs_logs) == 1 + assert ( + mock_realization_state.current_state.id == MockRealizationState.current_state.id + ) + + +@pytest.mark.asyncio +async def test_poll_statuses_bjobs_returning_unrecognized_status( + mock_bjobs, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(None) + 
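# The rule this test exercises, sketched standalone (illustrative helper): any
# bjobs status outside LSFDriver.LSF_STATUSES makes polling raise, and
# poll_statuses re-raises that ValueError as a RuntimeError.
KNOWN_LSF_STATUSES = [
    "PEND", "SSUSP", "PSUSP", "USUSP", "RUN",
    "EXIT", "ZOMBI", "DONE", "PDONE", "UNKWN",
]

def validate_lsf_status(job_id: str, status: str) -> None:
    if status not in KNOWN_LSF_STATUSES:
        raise ValueError(f"The lsf_status {status} for job {job_id} was not recognized")

validate_lsf_status("valid_job_id", "RUN")
try:
    validate_lsf_status("valid_job_id", "EATING")
except ValueError as err:
    assert "EATING" in str(err)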
lsf_driver.set_option("BJOBS_CMD", tmpdir / "mock_bjobs") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + create_fake_bjobs_result(tmpdir, job_id="valid_job_id", status="EATING") + + mock_realization_state = MockRealizationState() + lsf_driver._realstate_to_lsfid[mock_realization_state] = "valid_job_id" + lsf_driver._lsfid_to_realstate["valid_job_id"] = mock_realization_state + + with pytest.raises( + RuntimeError, + match="The lsf_status EATING for job valid_job_id was not recognized", + ): + await lsf_driver.poll_statuses() + + bjobs_logs = Path("bjobs_log").read_text(encoding="utf-8").strip().split("\n") + + # Should only call bjobs once + assert len(bjobs_logs) == 1 + + +@pytest.mark.asyncio +async def test_poll_statuses_bjobs_returning_updated_state( + mock_bjobs, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BJOBS_CMD", tmpdir / "mock_bjobs") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + create_fake_bjobs_result(tmpdir, job_id="valid_job_id", status="RUN") + + mock_realization_state = MockRealizationState() + mock_realization_state.current_state.id = "PENDING" + lsf_driver._realstate_to_lsfid[mock_realization_state] = "valid_job_id" + lsf_driver._lsfid_to_realstate["valid_job_id"] = mock_realization_state + + await lsf_driver.poll_statuses() + + bjobs_logs = Path("bjobs_log").read_text(encoding="utf-8").strip().split("\n") + + # Should only call bjobs once + assert len(bjobs_logs) == 1 + + # Should update realization state + assert mock_realization_state.current_state.id == "RUNNING" + + +@pytest.mark.asyncio +async def test_kill_bkill_non_existent_jobid_exit_code_1( + mock_bkill, caplog, tmpdir, monkeypatch +): + monkeypatch.chdir(tmpdir) + lsf_driver = LSFDriver(None) + lsf_driver.set_option("BKILL_CMD", tmpdir / "mock_bkill") + lsf_driver._max_attempt = 3 + lsf_driver._retry_sleep_period = 0 + mock_realization_state = MockRealizationState() + lsf_driver._realstate_to_lsfid[mock_realization_state] = "non_existent_jobid" + + with pytest.raises(RuntimeError, match="Maximum number of kill errors exceeded"): + await lsf_driver.kill(mock_realization_state) + + output = caplog.text + out_log = output.split("\n") + job_ids_from_file = Path("job_ids").read_text(encoding="utf-8").strip().split("\n") + assert len(job_ids_from_file) == lsf_driver._max_attempt + print(f"{out_log=}") + assert ( + len(re.findall("bkill: jobid non_existent_jobid not found", output)) + == lsf_driver._max_attempt + ) + + assert ( + len(re.findall("returned non-zero exitcode: 1", output)) + == lsf_driver._max_attempt + ) + + +@pytest.mark.parametrize( + "options, expected_list", + [ + [[("LSF_QUEUE", "pytest_queue")], ["bsub", "-q", "pytest_queue"]], + [[("BSUB_CMD", "/bin/mock/bsub")], ["/bin/mock/bsub"]], + ], +) +def test_lsf_parse_submit_cmd_adds_driver_options( + options: list[Tuple[str, str]], expected_list +): + lsf_driver = LSFDriver(options) + submit_command_list = lsf_driver.parse_submit_cmd() + assert submit_command_list == expected_list + + +@pytest.mark.parametrize( + "additional_parameters", [[["test0", "test2", "/home/test3.py"]], [[3, 2]], [[]]] +) +def test_lsf_parse_submit_cmd_adds_additional_parameters( + additional_parameters: list[str], +): + lsf_driver = LSFDriver(None) + submit_command_list = lsf_driver.parse_submit_cmd(*additional_parameters) + assert submit_command_list == ["bsub"] + additional_parameters + + +@pytest.mark.parametrize( + "options, additional_parameters, expected_list", 
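# A standalone sketch of the assembly rule parse_submit_cmd follows and that the
# parametrized tests around here check: an optional BSUB_CMD override, then
# "-q <LSF_QUEUE>" when that option is set, then any positional parameters.
from typing import Dict, List

def build_submit_cmd(options: Dict[str, str], *additional_parameters: str) -> List[str]:
    cmd = [options.get("BSUB_CMD", "bsub")]
    if "LSF_QUEUE" in options:
        cmd += ["-q", options["LSF_QUEUE"]]
    return cmd + list(additional_parameters)

assert build_submit_cmd({"LSF_QUEUE": "pytest_queue"}) == ["bsub", "-q", "pytest_queue"]
assert build_submit_cmd({"BSUB_CMD": "/bin/mock/bsub"}, "job.sh") == ["/bin/mock/bsub", "job.sh"]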
+ [ + [ + [("LSF_QUEUE", "pytest_queue")], + ["test0", "test2", "/home/test3.py"], + ["bsub", "-q", "pytest_queue", "test0", "test2", "/home/test3.py"], + ], + [ + [("LSF_QUEUE", "pytest_queue"), ("BSUB_CMD", "/bin/test_bsub")], + ["test0", "test2", "/home/test3.py"], + [ + "/bin/test_bsub", + "-q", + "pytest_queue", + "test0", + "test2", + "/home/test3.py", + ], + ], + ], +) +def test_lsf_parse_submit_cmd_adds_additional_parameters_after_options( + options: list[tuple[str, str]], + additional_parameters: list[str], + expected_list: list[str], +): + lsf_driver = LSFDriver(options) + submit_command_list = lsf_driver.parse_submit_cmd(*additional_parameters) + assert submit_command_list == expected_list + + +@pytest.mark.parametrize( + "driver_options, expected_bsub_options", + [[[("LSF_QUEUE", "test_queue")], ["-q test_queue"]]], +) +@pytest.mark.asyncio +async def test_lsf_submit_lsf_queue_option_is_added( + driver_options: list[Tuple[str, str]], + expected_bsub_options: list[str], + mock_bsub, + tmpdir, + monkeypatch, +): + monkeypatch.chdir(tmpdir) + + lsf_driver = LSFDriver(driver_options) + lsf_driver.set_option("BSUB_CMD", tmpdir / "mock_bsub") + + mock_realization_state = MockRealizationState() + mock_realization_state.realization.job_script = "valid_script.sh" + await lsf_driver.submit(mock_realization_state) + + command_called = Path("bsub_log").read_text(encoding="utf-8").strip() + assert len(command_called.split("\n")) == 1 -@pytest.fixture(params=["success", "fail"]) -def mock_bsub(request, tmp_path): - if request.param == "success": - return make_mock_bsub(tmp_path / "mock_bsub") - else: - make_mock_bsub(tmp_path / "success_bsub") - return make_failing_bsub(tmp_path / "mock_bsub", tmp_path / "success_bsub") + for expected_bsub_option in expected_bsub_options: + assert expected_bsub_option in command_called @pytest.fixture @@ -171,6 +512,7 @@ def copy_lsf_poly_case(copy_poly_case, tmp_path): fh.writelines(config) +@pytest.mark.skip(reason="Integration Test - does not work with the new python driver") @pytest.mark.usefixtures( "copy_lsf_poly_case", "mock_bsub",
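# A sketch of the kill-command assembly the bkill test above drives (illustrative
# helper mirroring kill()/_kill in the driver): the command is an optional
# BKILL_CMD override followed by the LSF job id, and run_with_retries keeps
# retrying on a non-zero exit code until "Maximum number of kill errors exceeded"
# is raised.
from typing import Dict, List

def build_kill_cmd(options: Dict[str, str], lsf_job_id: str) -> List[str]:
    return [options.get("BKILL_CMD", "bkill"), lsf_job_id]

assert build_kill_cmd({}, "1234") == ["bkill", "1234"]
assert build_kill_cmd({"BKILL_CMD": "/usr/bin/mock_bkill"}, "1234") == ["/usr/bin/mock_bkill", "1234"]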