Skip to content

Commit

Permalink
add docs
Browse files Browse the repository at this point in the history
  • Loading branch information
k-yoshimi committed Oct 30, 2024
1 parent fd44a9c commit d063b71
Show file tree
Hide file tree
Showing 35 changed files with 1,477 additions and 216 deletions.
38 changes: 37 additions & 1 deletion src/odatse/_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,37 @@


class Info:
"""
A class to represent the information structure for the data-analysis software.
"""

base: dict
algorithm: dict
solver: dict
runner: dict

def __init__(self, d: Optional[MutableMapping] = None):
"""
Initialize the Info object.
Parameters:
d (Optional[MutableMapping]): A dictionary to initialize the Info object.
"""
if d is not None:
self.from_dict(d)
else:
self._cleanup()

def from_dict(self, d: MutableMapping) -> None:
"""
Initialize the Info object from a dictionary.
Parameters:
d (MutableMapping): A dictionary containing the information to initialize the Info object.
Raises:
exception.InputError: If any required section is missing in the input dictionary.
"""
for section in ["base", "algorithm", "solver"]:
if section not in d:
raise exception.InputError(
Expand All @@ -58,6 +77,9 @@ def from_dict(self, d: MutableMapping) -> None:
)

def _cleanup(self) -> None:
"""
Reset the Info object to its default state.
"""
self.base = {}
self.base["root_dir"] = Path(".").absolute()
self.base["output_dir"] = self.base["root_dir"]
Expand All @@ -67,6 +89,20 @@ def _cleanup(self) -> None:

@classmethod
def from_file(cls, file_name, fmt="", **kwargs):
"""
Create an Info object from a file.
Parameters:
file_name (str): The name of the file to load the information from.
fmt (str): The format of the file (default is "").
**kwargs: Additional keyword arguments.
Returns:
Info: An Info object initialized with the data from the file.
Raises:
ValueError: If the file format is unsupported.
"""
if fmt == "toml" or fnmatch(file_name.lower(), "*.toml"):
inp = {}
if mpi.rank() == 0:
Expand All @@ -75,4 +111,4 @@ def from_file(cls, file_name, fmt="", **kwargs):
inp = mpi.comm().bcast(inp, root=0)
return cls(inp)
else:
raise ValueError("unsupported file format: {}".format(file_name))
raise ValueError("unsupported file format: {}".format(file_name))
5 changes: 5 additions & 0 deletions src/odatse/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@


def main():
"""
Main function to run the data-analysis software for quantum beam diffraction experiments
on 2D material structures. It parses command-line arguments, loads the input file,
selects the appropriate algorithm and solver, and executes the analysis.
"""
import argparse

parser = argparse.ArgumentParser(
Expand Down
58 changes: 53 additions & 5 deletions src/odatse/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,31 @@
class Run(metaclass=ABCMeta):
def __init__(self, nprocs=None, nthreads=None, comm=None):
"""
Initialize the Run class.
Parameters
----------
nprocs : int
Number of process which one solver uses
Number of processes which one solver uses.
nthreads : int
Number of threads which one solver process uses
Number of threads which one solver process uses.
comm : MPI.Comm
MPI Communicator
MPI Communicator.
"""
self.nprocs = nprocs
self.nthreads = nthreads
self.comm = comm

@abstractmethod
def submit(self, solver):
"""
Abstract method to submit a solver.
Parameters
----------
solver : object
Solver object to be submitted.
"""
pass


Expand All @@ -64,10 +74,18 @@ def __init__(self,
mapping = None,
limitation = None) -> None:
"""
Initialize the Runner class.
Parameters
----------
Solver: odatse.solver.SolverBase object
solver : odatse.solver.SolverBase
Solver object.
info : Optional[odatse.Info]
Information object.
mapping : object, optional
Mapping object.
limitation : object, optional
Limitation object.
"""
self.solver = solver
self.solver_name = solver.name
Expand All @@ -82,7 +100,7 @@ def __init__(self,
else:
# trivial mapping
self.mapping = odatse.util.mapping.TrivialMapping()

if limitation is not None:
self.limitation = limitation
elif "limitation" in info.runner:
Expand All @@ -92,11 +110,38 @@ def __init__(self,
self.limitation = odatse.util.limitation.Unlimited()

def prepare(self, proc_dir: Path):
"""
Prepare the logger with the given process directory.
Parameters
----------
proc_dir : Path
Path to the process directory.
"""
self.logger.prepare(proc_dir)

def submit(
self, x: np.ndarray, args = (), nprocs: int = 1, nthreads: int = 1
) -> float:
"""
Submit the solver with the given parameters.
Parameters
----------
x : np.ndarray
Input array.
args : tuple, optional
Additional arguments.
nprocs : int, optional
Number of processes.
nthreads : int, optional
Number of threads.
Returns
-------
float
Result of the solver evaluation.
"""
if self.limitation.judge(x):
xp = self.mapping(x)
result = self.solver.evaluate(xp, args)
Expand All @@ -106,4 +151,7 @@ def submit(
return result

def post(self) -> None:
"""
Write the logger data.
"""
self.logger.write()
73 changes: 72 additions & 1 deletion src/odatse/algorithm/_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@
from mpi4py import MPI

class AlgorithmStatus(IntEnum):
"""Enumeration for the status of the algorithm."""
INIT = 1
PREPARE = 2
RUN = 3

class AlgorithmBase(metaclass=ABCMeta):
"""Base class for algorithms, providing common functionality and structure."""

mpicomm: Optional["MPI.Comm"]
mpisize: int
mpirank: int
Expand All @@ -70,6 +73,13 @@ def __init__(
runner: Optional[odatse.Runner] = None,
run_mode: str = "initial"
) -> None:
"""
Initialize the algorithm with the given information and runner.
:param info: Information object containing algorithm and base parameters.
:param runner: Optional runner object to execute the algorithm.
:param run_mode: Mode in which the algorithm should run.
"""
self.mpicomm = mpi.comm()
self.mpisize = mpi.size()
self.mpirank = mpi.rank()
Expand Down Expand Up @@ -119,6 +129,11 @@ def __init__(
self.set_runner(runner)

def __init_rng(self, info: odatse.Info) -> None:
"""
Initialize the random number generator.
:param info: Information object containing algorithm parameters.
"""
seed = info.algorithm.get("seed", None)
seed_delta = info.algorithm.get("seed_delta", 314159)

Expand All @@ -128,9 +143,17 @@ def __init_rng(self, info: odatse.Info) -> None:
self.rng = np.random.RandomState(seed + self.mpirank * seed_delta)

def set_runner(self, runner: odatse.Runner) -> None:
"""
Set the runner for the algorithm.
:param runner: Runner object to execute the algorithm.
"""
self.runner = runner

def prepare(self) -> None:
"""
Prepare the algorithm for execution.
"""
if self.runner is None:
msg = "Runner is not assigned"
raise RuntimeError(msg)
Expand All @@ -139,9 +162,13 @@ def prepare(self) -> None:

@abstractmethod
def _prepare(self) -> None:
"""Abstract method to be implemented by subclasses for preparation steps."""
pass

def run(self) -> None:
"""
Run the algorithm.
"""
if self.status < AlgorithmStatus.PREPARE:
msg = "algorithm has not prepared yet"
raise RuntimeError(msg)
Expand All @@ -155,9 +182,15 @@ def run(self) -> None:

@abstractmethod
def _run(self) -> None:
"""Abstract method to be implemented by subclasses for running steps."""
pass

def post(self) -> Dict:
"""
Perform post-processing after the algorithm has run.
:return: Dictionary containing post-processing results.
"""
if self.status < AlgorithmStatus.RUN:
msg = "algorithm has not run yet"
raise RuntimeError(msg)
Expand All @@ -169,9 +202,13 @@ def post(self) -> Dict:

@abstractmethod
def _post(self) -> Dict:
"""Abstract method to be implemented by subclasses for post-processing steps."""
pass

def main(self):
"""
Main method to execute the algorithm.
"""
time_sta = time.perf_counter()
self.prepare()
time_end = time.perf_counter()
Expand All @@ -197,6 +234,11 @@ def main(self):
return result

def write_timer(self, filename: Path):
"""
Write the timing information to a file.
:param filename: Path to the file where timing information will be written.
"""
with open(filename, "w") as fw:
fw.write("#in units of seconds\n")

Expand All @@ -214,6 +256,13 @@ def output_file(type):
output_file("post")

def _save_data(self, data, filename="state.pickle", ngen=3) -> None:
"""
Save data to a file with versioning.
:param data: Data to be saved.
:param filename: Name of the file to save the data.
:param ngen: Number of generations for versioning.
"""
try:
fn = Path(filename + ".tmp")
with open(fn, "wb") as f:
Expand All @@ -235,6 +284,12 @@ def _save_data(self, data, filename="state.pickle", ngen=3) -> None:
print("save_state: write to {}".format(filename))

def _load_data(self, filename="state.pickle") -> Dict:
"""
Load data from a file.
:param filename: Name of the file to load the data from.
:return: Dictionary containing the loaded data.
"""
if Path(filename).exists():
try:
fn = Path(filename)
Expand All @@ -250,12 +305,20 @@ def _load_data(self, filename="state.pickle") -> Dict:
return data

def _show_parameters(self):
"""
Show the parameters of the algorithm.
"""
if self.mpirank == 0:
info = flatten_dict(self.info)
for k, v in info.items():
print("{:16s}: {}".format(k, v))

def _check_parameters(self, param=None):
"""
Check the parameters of the algorithm against previous parameters.
:param param: Previous parameters to check against.
"""
info = flatten_dict(self.info)
info_prev = flatten_dict(param)

Expand All @@ -269,6 +332,14 @@ def _check_parameters(self, param=None):

# utility
def flatten_dict(d, parent_key="", separator="."):
"""
Flatten a nested dictionary.
:param d: Dictionary to flatten.
:param parent_key: Key for the parent dictionary.
:param separator: Separator to use between keys.
:return: Flattened dictionary.
"""
items = []
if d:
for key_, val in d.items():
Expand All @@ -277,4 +348,4 @@ def flatten_dict(d, parent_key="", separator="."):
items.extend(flatten_dict(val, key, separator=separator).items())
else:
items.append((key, val))
return dict(items)
return dict(items)
Loading

0 comments on commit d063b71

Please sign in to comment.