diff --git a/flepimop/gempyor_pkg/src/gempyor/batch.py b/flepimop/gempyor_pkg/src/gempyor/batch.py index ac90d7692..701828662 100644 --- a/flepimop/gempyor_pkg/src/gempyor/batch.py +++ b/flepimop/gempyor_pkg/src/gempyor/batch.py @@ -31,11 +31,11 @@ from .logging import get_script_logger from .utils import _format_cli_options, _git_checkout, _git_head, _shutil_which, config from .shared_cli import ( - MEMORY_MB, - NONNEGATIVE_DURATION, + DurationParamType, + MemoryParamType, cli, - config_files_argument, config_file_options, + config_files_argument, log_cli_inputs, mock_context, parse_config_files, @@ -903,13 +903,13 @@ def _submit_scenario_job( ), click.Option( param_decls=["--simulation-time", "simulation_time"], - type=NONNEGATIVE_DURATION, + type=DurationParamType(True), default="3min", help="The time limit per a simulation.", ), click.Option( param_decls=["--initial-time", "initial_time"], - type=NONNEGATIVE_DURATION, + type=DurationParamType(True), default="20min", help="The initialization time limit.", ), @@ -971,7 +971,7 @@ def _submit_scenario_job( ), click.Option( param_decls=["--memory", "memory"], - type=MEMORY_MB, + type=MemoryParamType("mb", as_int=True), default=None, help="Override for the amount of memory per node to use in MB.", ), @@ -1130,19 +1130,12 @@ def _click_submit(ctx: click.Context = mock_context, **kwargs: Any) -> None: logger.info("Setting a total job time limit of %s minutes", job_time_limit.format()) # Job resources - memory = None if kwargs["memory"] is None else math.ceil(kwargs["memory"]) - if memory != kwargs["memory"]: - logger.warning( - "The requested memory of %.3fMB has been rounded up to %uMB for submission", - kwargs["memory"], - memory, - ) job_resources = JobResources.from_presets( job_size, inference_method, nodes=kwargs["nodes"], cpus=kwargs["cpus"], - memory=memory, + memory=kwargs["memory"], ) logger.info("Requesting the resources %s for this job.", job_resources) diff --git a/flepimop/gempyor_pkg/src/gempyor/info.py b/flepimop/gempyor_pkg/src/gempyor/info.py index 00d094c09..bdcdc7b36 100644 --- a/flepimop/gempyor_pkg/src/gempyor/info.py +++ b/flepimop/gempyor_pkg/src/gempyor/info.py @@ -1,3 +1,31 @@ +""" +Retrieving static information from developer managed yaml files. + +Currently, it includes utilities for handling cluster-specific information, but it can +be extended to other categories as needed. + +Classes: + Module: Represents a software module with a name and optional version. + PathExport: Represents a path export with a path, prepend flag, and error handling. + Cluster: Represents a cluster with a name, list of modules, and list of path + exports. + +Functions: + get_cluster_info: Retrieves cluster-specific information. + +Examples: + >>> from pprint import pprint + >>> from gempyor.info import get_cluster_info + >>> cluster_info = get_cluster_info("longleaf") + >>> cluster_info.name + 'longleaf' + >>> pprint(cluster_info.modules) + [Module(name='gcc', version='9.1.0'), + Module(name='anaconda', version='2023.03'), + Module(name='git', version=None), + Module(name='aws', version=None)] +""" + __all__ = ["Cluster", "Module", "PathExport", "get_cluster_info"] @@ -12,23 +40,55 @@ class Module(BaseModel): + """ + A model representing a module to load. + + Attributes: + name: The name of the module to load. + version: The specific version of the module to load if there is one. + + See Also: + [Lmod](https://lmod.readthedocs.io/en/latest/) + """ + name: str version: str | None = None class PathExport(BaseModel): + """ + A model representing the export path configuration. + + Attributes: + path: The file system path of the path to add to the `$PATH` environment + variable. + prepend: A flag indicating whether to prepend additional information to the + `$PATH` environment variable. + error_if_missing: A flag indicating whether to raise an error if the path is + missing. + """ + path: Path prepend: bool = True error_if_missing: bool = False class Cluster(BaseModel): + """ + A model representing a cluster configuration. + + Attributes: + name: The name of the cluster. + modules: A list of modules associated with the cluster. + path_exports: A list of path exports for the cluster. + """ + name: str modules: list[Module] = [] path_exports: list[PathExport] = [] -T = TypeVar("T", bound=BaseModel) +_BASE_MODEL_TYPE = TypeVar("T", bound=BaseModel) _CLUSTER_FQDN_REGEXES: tuple[tuple[str, Pattern], ...] = ( @@ -38,8 +98,8 @@ class Cluster(BaseModel): def _get_info( - category: str, name: str, model: type[T], flepi_path: os.PathLike | None -) -> T: + category: str, name: str, model: type[_BASE_MODEL_TYPE], flepi_path: os.PathLike | None +) -> _BASE_MODEL_TYPE: """ Get and parse an information yaml file. @@ -79,8 +139,14 @@ def get_cluster_info(name: str | None, flepi_path: os.PathLike | None = None) -> flepi_path: Either a path like determine the directory to look for the info directory in or `None` to use the `FLEPI_PATH` environment variable. - Returns + Returns: An object containing the information about the `name` cluster. + + Examples: + >>> from gempyor.info import get_cluster_info + >>> cluster_info = get_cluster_info("longleaf") + >>> cluster_info.name + 'longleaf' """ name = _infer_cluster_from_fqdn() if name is None else name return _get_info("cluster", name, Cluster, flepi_path) diff --git a/flepimop/gempyor_pkg/src/gempyor/shared_cli.py b/flepimop/gempyor_pkg/src/gempyor/shared_cli.py index d47d6de96..230839784 100644 --- a/flepimop/gempyor_pkg/src/gempyor/shared_cli.py +++ b/flepimop/gempyor_pkg/src/gempyor/shared_cli.py @@ -7,6 +7,7 @@ from datetime import timedelta +from math import ceil import multiprocessing import pathlib import re @@ -143,6 +144,23 @@ def cli(ctx: click.Context) -> None: class DurationParamType(click.ParamType): + """ + A custom Click parameter type for parsing duration strings into `timedelta` objects. + + Attributes: + name: The name of the parameter type. + + Examples: + >>> from gempyor.shared_cli import DurationParamType + >>> duration_param_type = DurationParamType(False) + >>> duration_param_type.convert("23min", None, None) + datetime.timedelta(seconds=1380) + >>> duration_param_type.convert("2.5hr", None, None) + datetime.timedelta(seconds=9000) + >>> duration_param_type.convert("-2", None, None) + datetime.timedelta(days=-1, seconds=86280) + """ + name = "duration" _abbreviations = { "s": "seconds", @@ -173,6 +191,14 @@ def __init__( nonnegative: bool, default_unit: Literal["seconds", "minutes", "hours", "days", "weeks"] = "minutes", ) -> None: + """ + Initialize the instance based on parameter settings. + + Args: + nonnegative: If `True` negative durations are not allowed. + default_unit: The default unit to use if no unit is specified in the input + string. + """ super().__init__() self._nonnegative = nonnegative self._duration_regex = re.compile( @@ -184,6 +210,24 @@ def __init__( def convert( self, value: Any, param: click.Parameter | None, ctx: click.Context | None ) -> timedelta: + """ + Converts a string representation of a duration into a `timedelta` object. + + Args: + value: The value to convert, expected to be a string like representation of + a duration. + param: The Click parameter object for context in errors. + ctx: The Click context object for context in errors. + + Returns: + The converted duration as a `timedelta` object. + + Raises: + click.BadParameter: If the value is not a valid duration based on the + format. + click.BadParameter: If the duration is negative and the class was + initialized with `nonnegative` set to `True`. + """ value = str(value).strip() if (m := self._duration_regex.match(value)) is None: self.fail(f"{value!r} is not a valid duration", param, ctx) @@ -195,11 +239,24 @@ def convert( return timedelta(**kwargs) -DURATION = DurationParamType(nonnegative=False) -NONNEGATIVE_DURATION = DurationParamType(nonnegative=True) +class MemoryParamType(click.ParamType): + """ + A custom Click parameter type for parsing duration strings into `timedelta` objects. + Attributes: + name: The name of the parameter type. + + Examples: + >>> from gempyor.shared_cli import DurationParamType + >>> duration_param_type = DurationParamType(False) + >>> duration_param_type.convert("23min", None, None) + datetime.timedelta(seconds=1380) + >>> duration_param_type.convert("2.5hr", None, None) + datetime.timedelta(seconds=9000) + >>> duration_param_type.convert("-2", None, None) + datetime.timedelta(days=-1, seconds=86280) + """ -class MemoryParamType(click.ParamType): name = "memory" _units = { "kb": 1024.0**1.0, @@ -212,38 +269,63 @@ class MemoryParamType(click.ParamType): "tb": 1024.0**4.0, } - def __init__(self, unit: str) -> None: + def __init__(self, unit: str, as_int: bool = False) -> None: + """ + Initialize the instance based on parameter settings. + + Args: + unit: The output unit to use in the `convert` method. + as_int: if `True` the `convert` method returns an integer instead of a + float. + + Raises: + ValueError: If `unit` is not a valid memory unit size. + """ super().__init__() if (unit := unit.lower()) not in self._units.keys(): raise ValueError( f"The `unit` given is not valid, given '{unit}' and " - "must be one of: {', '.join(self._units.keys())}." + f"must be one of: {', '.join(self._units.keys())}." ) self._unit = unit self._regex = re.compile( rf"^(([0-9]+)?(\.[0-9]+)?)({'|'.join(self._units.keys())})?$", flags=re.IGNORECASE, ) + self._as_int = as_int def convert( self, value: Any, param: click.Parameter | None, ctx: click.Context | None - ) -> float: + ) -> float | int: + """ + Converts a string representation of a memory size into a numeric. + + Args: + value: The value to convert, expected to be a string like representation of + memory size. + param: The Click parameter object for context in errors. + ctx: The Click context object for context in errors. + + Returns: + The converted memory size as a numeric. Specifically an integer if the + `as_int` attribute is `True` and float otherwise. + + Raises: + click.BadParameter: If the value is not a valid memory size based on the + format. + """ value = str(value).strip() if (m := self._regex.match(value)) is None: self.fail(f"{value!r} is not a valid memory size.", param, ctx) number, _, _, unit = m.groups() unit = unit.lower() if unit == self._unit: - return float(number) - return (self._units.get(unit, self._unit) * float(number)) / ( - self._units.get(self._unit) - ) - - -MEMORY_KB = MemoryParamType("kb") -MEMORY_MB = MemoryParamType("mb") -MEMORY_GB = MemoryParamType("gb") -MEMORY_TB = MemoryParamType("tb") + result = float(number) + else: + result = (self._units.get(unit, self._unit) * float(number)) / ( + self._units.get(self._unit) + ) + return ceil(result) if self._as_int else result def click_helpstring( diff --git a/flepimop/gempyor_pkg/tests/shared_cli/test_memory_param_type_class.py b/flepimop/gempyor_pkg/tests/shared_cli/test_memory_param_type_class.py index 1db89a5a7..f3c60e504 100644 --- a/flepimop/gempyor_pkg/tests/shared_cli/test_memory_param_type_class.py +++ b/flepimop/gempyor_pkg/tests/shared_cli/test_memory_param_type_class.py @@ -27,6 +27,7 @@ def test_invalid_value_bad_parameter(value: Any) -> None: @pytest.mark.parametrize("unit", MemoryParamType._units.keys()) +@pytest.mark.parametrize("as_int", (True, False)) @pytest.mark.parametrize( "number", [random.randint(1, 1000) for _ in range(3)] # int @@ -35,22 +36,31 @@ def test_invalid_value_bad_parameter(value: Any) -> None: random.randint(1, 25) + random.random() for _ in range(3) ], # float with numbers left of the decimal ) -def test_convert_acts_as_identity(unit: str, number: int) -> None: - memory = MemoryParamType(unit) - assert memory.convert(f"{number}{unit}".lstrip("0"), None, None) == number - assert memory.convert(f"{number}{unit.upper()}".lstrip("0"), None, None) == number +def test_convert_acts_as_identity(unit: str, as_int: bool, number: int | float) -> None: + memory = MemoryParamType(unit, as_int=as_int) + for u in (unit, unit.upper()): + result = memory.convert(f"{number}{u}".lstrip("0"), None, None) + assert isinstance(result, int if as_int else float) + assert abs(result - number) <= 1 if as_int else result == number @pytest.mark.parametrize( - ("unit", "value", "expected"), + ("unit", "as_int", "value", "expected"), ( - ("gb", "1.2gb", 1.2), - ("kb", "1mb", 1024.0), - ("gb", "30mb", 30.0 / 1024.0), - ("kb", "2tb", 2.0 * (1024.0**3.0)), - ("mb", "0.1gb", 0.1 * 1024.0), + ("gb", False, "1.2gb", 1.2), + ("gb", True, "1.2gb", 2), + ("kb", False, "1mb", 1024.0), + ("kb", True, "1mb", 1024), + ("gb", False, "30mb", 30.0 / 1024.0), + ("gb", True, "30mb", 1), + ("kb", False, "2tb", 2.0 * (1024.0**3.0)), + ("kb", True, "2tb", 2147483648), + ("mb", False, "0.1gb", 0.1 * 1024.0), + ("mb", True, "0.1gb", 103), ), ) -def test_exact_results_for_select_inputs(unit: str, value: Any, expected: float) -> None: - memory = MemoryParamType(unit) +def test_exact_results_for_select_inputs( + unit: str, as_int: bool, value: Any, expected: float | int +) -> None: + memory = MemoryParamType(unit, as_int=as_int) assert memory.convert(value, None, None) == expected