Skip to content

Commit

Permalink
reformat docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
aoymt committed Oct 31, 2024
1 parent 1fd6144 commit 50acb05
Show file tree
Hide file tree
Showing 19 changed files with 664 additions and 297 deletions.
48 changes: 31 additions & 17 deletions src/odatse/_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ def __init__(self, d: Optional[MutableMapping] = None):
"""
Initialize the Info object.
Parameters:
d (Optional[MutableMapping]): A dictionary to initialize the Info object.
Parameters
----------
d : MutableMapping (optional)
A dictionary to initialize the Info object.
"""
if d is not None:
self.from_dict(d)
Expand All @@ -51,11 +53,15 @@ def from_dict(self, d: MutableMapping) -> None:
"""
Initialize the Info object from a dictionary.
Parameters:
d (MutableMapping): A dictionary containing the information to initialize the Info object.
Parameters
----------
d : MutableMapping
A dictionary containing the information to initialize the Info object.
Raises:
exception.InputError: If any required section is missing in the input dictionary.
Raises
------
exception.InputError
If any required section is missing in the input dictionary.
"""
for section in ["base", "algorithm", "solver"]:
if section not in d:
Expand Down Expand Up @@ -92,16 +98,24 @@ def from_file(cls, file_name, fmt="", **kwargs):
"""
Create an Info object from a file.
Parameters:
file_name (str): The name of the file to load the information from.
fmt (str): The format of the file (default is "").
**kwargs: Additional keyword arguments.
Returns:
Info: An Info object initialized with the data from the file.
Raises:
ValueError: If the file format is unsupported.
Parameters
----------
file_name : str
The name of the file to load the information from.
fmt : str
The format of the file (default is "").
**kwargs
Additional keyword arguments.
Returns
-------
Info
An Info object initialized with the data from the file.
Raises
------
ValueError
If the file format is unsupported.
"""
if fmt == "toml" or fnmatch(file_name.lower(), "*.toml"):
inp = {}
Expand All @@ -111,4 +125,4 @@ def from_file(cls, file_name, fmt="", **kwargs):
inp = mpi.comm().bcast(inp, root=0)
return cls(inp)
else:
raise ValueError("unsupported file format: {}".format(file_name))
raise ValueError("unsupported file format: {}".format(file_name))
6 changes: 3 additions & 3 deletions src/odatse/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@

def main():
"""
Main function to run the data-analysis software for quantum beam diffraction experiments
on 2D material structures. It parses command-line arguments, loads the input file,
selects the appropriate algorithm and solver, and executes the analysis.
Main function to run the data-analysis software for quantum beam diffraction experiments
on 2D material structures. It parses command-line arguments, loads the input file,
selects the appropriate algorithm and solver, and executes the analysis.
"""
import argparse

Expand Down
77 changes: 59 additions & 18 deletions src/odatse/algorithm/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,14 @@ def __init__(
"""
Initialize the algorithm with the given information and runner.
:param info: Information object containing algorithm and base parameters.
:param runner: Optional runner object to execute the algorithm.
:param run_mode: Mode in which the algorithm should run.
Parameters
----------
info : Info
Information object containing algorithm and base parameters.
runner : Runner (optional)
Optional runner object to execute the algorithm.
run_mode : str
Mode in which the algorithm should run.
"""
self.mpicomm = mpi.comm()
self.mpisize = mpi.size()
Expand Down Expand Up @@ -132,7 +137,10 @@ def __init_rng(self, info: odatse.Info) -> None:
"""
Initialize the random number generator.
:param info: Information object containing algorithm parameters.
Parameters
----------
info : Info
Information object containing algorithm parameters.
"""
seed = info.algorithm.get("seed", None)
seed_delta = info.algorithm.get("seed_delta", 314159)
Expand All @@ -146,7 +154,10 @@ def set_runner(self, runner: odatse.Runner) -> None:
"""
Set the runner for the algorithm.
:param runner: Runner object to execute the algorithm.
Parameters
----------
runner : Runner
Runner object to execute the algorithm.
"""
self.runner = runner

Expand Down Expand Up @@ -189,7 +200,10 @@ def post(self) -> Dict:
"""
Perform post-processing after the algorithm has run.
:return: Dictionary containing post-processing results.
Returns
-------
Dict
Dictionary containing post-processing results.
"""
if self.status < AlgorithmStatus.RUN:
msg = "algorithm has not run yet"
Expand Down Expand Up @@ -237,7 +251,10 @@ def write_timer(self, filename: Path):
"""
Write the timing information to a file.
:param filename: Path to the file where timing information will be written.
Parameters
----------
filename : Path
Path to the file where timing information will be written.
"""
with open(filename, "w") as fw:
fw.write("#in units of seconds\n")
Expand All @@ -259,9 +276,14 @@ def _save_data(self, data, filename="state.pickle", ngen=3) -> None:
"""
Save data to a file with versioning.
:param data: Data to be saved.
:param filename: Name of the file to save the data.
:param ngen: Number of generations for versioning.
Parameters
----------
data
Data to be saved.
filename
Name of the file to save the data.
ngen : int, default: 3
Number of generations for versioning.
"""
try:
fn = Path(filename + ".tmp")
Expand All @@ -287,8 +309,15 @@ def _load_data(self, filename="state.pickle") -> Dict:
"""
Load data from a file.
:param filename: Name of the file to load the data from.
:return: Dictionary containing the loaded data.
Parameters
----------
filename
Name of the file to load the data from.
Returns
-------
Dict
Dictionary containing the loaded data.
"""
if Path(filename).exists():
try:
Expand Down Expand Up @@ -317,7 +346,10 @@ def _check_parameters(self, param=None):
"""
Check the parameters of the algorithm against previous parameters.
:param param: Previous parameters to check against.
Parameters
----------
param (optional)
Previous parameters to check against.
"""
info = flatten_dict(self.info)
info_prev = flatten_dict(param)
Expand All @@ -335,10 +367,19 @@ def flatten_dict(d, parent_key="", separator="."):
"""
Flatten a nested dictionary.
:param d: Dictionary to flatten.
:param parent_key: Key for the parent dictionary.
:param separator: Separator to use between keys.
:return: Flattened dictionary.
Parameters
----------
d
Dictionary to flatten.
parent_key : str, default : ""
Key for the parent dictionary.
separator : str, default : "."
Separator to use between keys.
Returns
-------
dict
Flattened dictionary.
"""
items = []
if d:
Expand All @@ -348,4 +389,4 @@ def flatten_dict(d, parent_key="", separator="."):
items.extend(flatten_dict(val, key, separator=separator).items())
else:
items.append((key, val))
return dict(items)
return dict(items)
34 changes: 25 additions & 9 deletions src/odatse/algorithm/mapper_mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,16 @@ def __init__(self, info: odatse.Info,
"""
Initialize the Algorithm instance.
:param info: Information object containing algorithm parameters.
:param runner: Optional runner object for submitting tasks.
:param domain: Optional domain object, defaults to MeshGrid.
:param run_mode: Mode to run the algorithm, defaults to "initial".
Parameters
----------
info : Info
Information object containing algorithm parameters.
runner : Runner
Optional runner object for submitting tasks.
domain :
Optional domain object, defaults to MeshGrid.
run_mode : str
Mode to run the algorithm, defaults to "initial".
"""
super().__init__(info=info, runner=runner, run_mode=run_mode)

Expand Down Expand Up @@ -174,7 +180,10 @@ def _post(self) -> Dict:
"""
Post-process the results and gather data from all MPI ranks.
:return: Dictionary of results.
Returns
-------
Dict
Dictionary of results.
"""
if self.mpisize > 1:
fx_lists = self.mpicomm.allgather(self.fx_list)
Expand All @@ -196,7 +205,10 @@ def _save_state(self, filename) -> None:
"""
Save the current state of the algorithm to a file.
:param filename: The name of the file to save the state to.
Parameters
----------
filename
The name of the file to save the state to.
"""
data = {
"mpisize": self.mpisize,
Expand All @@ -212,8 +224,12 @@ def _load_state(self, filename, restore_rng=True):
"""
Load the state of the algorithm from a file.
:param filename: The name of the file to load the state from.
:param restore_rng: Whether to restore the random number generator state.
Parameters
----------
filename
The name of the file to load the state from.
restore_rng : bool
Whether to restore the random number generator state.
"""
data = self._load_data(filename)
if not data:
Expand All @@ -230,4 +246,4 @@ def _load_state(self, filename, restore_rng=True):

self.fx_list = data["fx_list"]

assert len(self.mesh_list) == data["mesh_size"]
assert len(self.mesh_list) == data["mesh_size"]
30 changes: 22 additions & 8 deletions src/odatse/algorithm/min_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,16 @@ def __init__(self, info: odatse.Info,
"""
Initialize the Algorithm class.
:param info: Information object containing algorithm settings.
:param runner: Runner object for submitting jobs.
:param domain: Domain object defining the search space.
:param run_mode: Mode of running the algorithm.
Parameters
----------
info : Info
Information object containing algorithm settings.
runner : Runner
Runner object for submitting jobs.
domain :
Domain object defining the search space.
run_mode : str
Mode of running the algorithm.
"""
super().__init__(info=info, runner=runner, run_mode=run_mode)

Expand Down Expand Up @@ -132,9 +138,17 @@ def _f_calc(x_list: np.ndarray, iset) -> float:
"""
Calculate the objective function value.
:param x_list: List of variables.
:param iset: Set index.
:return: Objective function value.
Parameters
----------
x_list : np.ndarray
List of variables.
iset :
Set index.
Returns
-------
float
Objective function value.
"""
# check if within region -> boundary option in minimize
# note: 'bounds' option supported in scipy >= 1.7.0
Expand Down Expand Up @@ -259,4 +273,4 @@ def _post(self):
for x, y in zip(label_list, x0s[idx]):
fp.write(f"initial {x} = {y}\n")

return {"x": xs[idx], "fx": fxs[idx], "x0": x0s[idx]}
return {"x": xs[idx], "fx": fxs[idx], "x0": x0s[idx]}
Loading

0 comments on commit 50acb05

Please sign in to comment.