diff --git a/flepimop/gempyor_pkg/src/gempyor/seeding.py b/flepimop/gempyor_pkg/src/gempyor/seeding.py index 737e26aab..a06763b4a 100644 --- a/flepimop/gempyor_pkg/src/gempyor/seeding.py +++ b/flepimop/gempyor_pkg/src/gempyor/seeding.py @@ -6,6 +6,7 @@ from datetime import date import logging from typing import Any +import warnings import confuse import numba as nb @@ -23,9 +24,6 @@ logger = logging.getLogger(__name__) -## TODO: ideally here path_prefix should not be used and all files loaded from modinf - - # Internal functionality def _DataFrame2NumbaDict( df: pd.DataFrame, @@ -114,7 +112,22 @@ def _DataFrame2NumbaDict( # Exported functionality class Seeding(SimulationComponent): + """ + Class to handle the seeding of the simulation. + + Attributes: + seeding_config: The configuration for the seeding. + path_prefix: The path prefix to use when reading files. + """ + def __init__(self, config: confuse.ConfigView, path_prefix: str = "."): + """ + Initialize a seeding instance. + + Args: + config: The configuration for the seeding. + path_prefix: The path prefix to use when reading files. + """ self.seeding_config = config self.path_prefix = path_prefix @@ -127,6 +140,27 @@ def get_from_config( tf: date, input_filename: str | None, ) -> tuple[nb.typed.Dict, npt.NDArray[np.number]]: + """ + Get seeding data from the configuration. + + Args: + compartments: The compartments for the simulation. + subpop_struct: The subpopulation structure for the simulation. + n_days: The number of days in the simulation. + ti: The start date of the simulation. + tf: The end date of the simulation. + input_filename: The input filename to use for seeding data. Only used if + the seeding method is 'FolderDraw'. + + Returns: + A tuple containing the seeding data as a Numba dictionary and the seeding + amounts as a Numpy array. The seeding data is a dictionary with the + following keys: + - "seeding_sources": The source compartments for the seeding. + - "seeding_destinations": The destination compartments for the seeding. + - "seeding_subpops": The subpopulations for the seeding. + - "day_start_idx": The start index for each day in the seeding data. + """ method = "NoSeeding" if self.seeding_config is not None and "method" in self.seeding_config.keys(): method = self.seeding_config["method"].as_str() @@ -166,16 +200,10 @@ def get_from_config( else: raise ValueError(f"Unknown seeding method given, '{method}'.") - # Sorting by date is very important here for the seeding format necessary !!!! - # print(seeding.shape) + # Sorting by date is important for the seeding format seeding = seeding.sort_values(by="date", axis="index").reset_index() - # print(seeding) mask = (seeding["date"].dt.date > ti) & (seeding["date"].dt.date <= tf) seeding = seeding.loc[mask].reset_index() - # print(seeding.shape) - # print(seeding) - - # TODO: print. amounts = np.zeros(len(seeding)) if method == "PoissonDistributed": @@ -195,11 +223,37 @@ def get_from_config( def get_from_file( self, *args: Any, **kwargs: Any ) -> tuple[nb.typed.Dict, npt.NDArray[np.number]]: - """only difference with draw seeding is that the sim_id is now sim_id2load""" + """ + This method is deprecated. Use `get_from_config` instead. + + Args: + *args: Positional arguments to pass to `get_from_config`. + **kwargs: Keyword arguments to pass to `get_from_config`. + + Returns: + The result of `get_from_config`. + """ + warnings.warn( + "The 'get_from_file' method is deprecated. Use 'get_from_config' instead.", + DeprecationWarning, + ) return self.get_from_config(*args, **kwargs) -def SeedingFactory(config: confuse.ConfigView, path_prefix: str = "."): +def SeedingFactory(config: confuse.ConfigView, path_prefix: str = ".") -> Seeding: + """ + Create a Seeding instance based on the given configuration. + + This function will use the given configuration to either lookup a plugin class for + the seeding instance or fallback to the default Seeding class. + + Args: + config: The configuration for the seeding. + path_prefix: The path prefix to use when reading files. + + Returns: + A Seeding instance. + """ if config is not None and "method" in config.keys(): if config["method"].as_str() == "plugin": klass = utils.search_and_import_plugins_class(