Skip to content

Commit

Permalink
Review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jun 12, 2024
1 parent 047b043 commit 52c5e24
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 37 deletions.
46 changes: 23 additions & 23 deletions src/ert/shared/hook_implementations/workflows/export_runpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,24 +38,30 @@ def run(
) -> None:
input_ranges = [] if input_ranges is None else input_ranges
_args = " ".join(input_ranges).split() # Make sure args is a list of words
config = ert_config
self.ert_config = ert_config
run_paths = Runpaths(
jobname_format=config.model_config.jobname_format_string,
runpath_format=config.model_config.runpath_format_string,
filename=str(config.runpath_file),
substitution_list=config.substitution_list,
jobname_format=ert_config.model_config.jobname_format_string,
runpath_format=ert_config.model_config.runpath_format_string,
filename=str(ert_config.runpath_file),
substitution_list=ert_config.substitution_list,
)
run_paths.write_runpath_list(
*self.get_ranges(
_args,
ert_config.analysis_config.num_iterations,
ert_config.model_config.num_realizations,
)
)
run_paths.write_runpath_list(*self.get_ranges(_args))

def get_ranges(self, args: List[str]) -> Tuple[List[int], List[int]]:
realizations_rangestring, iterations_rangestring = self._get_rangestrings(args)
def get_ranges(
self, args: List[str], number_of_iterations: int, number_of_realizations: int
) -> Tuple[List[int], List[int]]:
realizations_rangestring, iterations_rangestring = self._get_rangestrings(
args, number_of_realizations
)
return (
self._list_from_rangestring(iterations_rangestring, number_of_iterations),
self._list_from_rangestring(
iterations_rangestring, self.number_of_iterations
),
self._list_from_rangestring(
realizations_rangestring, self.number_of_realizations
realizations_rangestring, number_of_realizations
),
)

Expand All @@ -67,21 +73,15 @@ def _list_from_rangestring(rangestring: str, size: int) -> List[int]:
mask = rangestring_to_mask(rangestring, size)
return [i for i, flag in enumerate(mask) if flag]

def _get_rangestrings(self, args: List[str]) -> Tuple[str, str]:
def _get_rangestrings(
self, args: List[str], number_of_realizations: int
) -> Tuple[str, str]:
if not args:
return (
f"0-{self.number_of_realizations-1}",
f"0-{number_of_realizations-1}",
"0-0", # weird default behavior, kept for backwards compatability
)
if "|" not in args:
raise ValueError("Expected | in EXPORT_RUNPATH arguments")
delimiter = args.index("|")
return " ".join(args[:delimiter]), " ".join(args[delimiter + 1 :])

@property
def number_of_realizations(self) -> int:
return self.ert_config.model_config.num_realizations

@property
def number_of_iterations(self) -> int:
return self.ert_config.analysis_config.num_iterations
19 changes: 5 additions & 14 deletions tests/unit_tests/all/plugins/test_export_runpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,11 @@

@pytest.fixture
def snake_oil_export_runpath_job(setup_case):
ert_config = setup_case("snake_oil", "snake_oil.ert")
setup_case("snake_oil", "snake_oil.ert")
plugin = ExportRunpathJob()
plugin.ert_config = ert_config
yield plugin


def test_export_runpath_number_of_realizations(snake_oil_export_runpath_job):
assert snake_oil_export_runpath_job.number_of_realizations == 25


def test_export_runpath_number_of_iterations(snake_oil_export_runpath_job):
assert snake_oil_export_runpath_job.number_of_iterations == 4


@dataclass
class WritingSetup:
write_mock: Mock
Expand All @@ -43,7 +34,7 @@ def test_export_runpath_empty_range(writing_setup):

writing_setup.write_mock.assert_called_with(
[0],
list(range(writing_setup.export_job.number_of_realizations)),
list(range(25)),
)


Expand All @@ -52,8 +43,8 @@ def test_export_runpath_star_parameter(writing_setup):
writing_setup.export_job.run(config, ["* | *"])

writing_setup.write_mock.assert_called_with(
list(range(writing_setup.export_job.number_of_iterations)),
list(range(writing_setup.export_job.number_of_realizations)),
list(range(4)),
list(range(25)),
)


Expand All @@ -63,7 +54,7 @@ def test_export_runpath_range_parameter(writing_setup):

writing_setup.write_mock.assert_called_with(
[1, 2],
list(range(writing_setup.export_job.number_of_realizations)),
list(range(25)),
)


Expand Down

0 comments on commit 52c5e24

Please sign in to comment.