diff --git a/CosmoAPI/__main__.py b/CosmoAPI/__main__.py index fb5b97d..2fc60d0 100644 --- a/CosmoAPI/__main__.py +++ b/CosmoAPI/__main__.py @@ -1,23 +1,24 @@ import argparse +from typing import Dict, Any from .api_io import load_yaml_file from .not_implemented import not_implemented_message -def gen_datavec(config, verbose=False): +def gen_datavec(config: Dict[str, Any], verbose: bool = False) -> None: # Functionality for generating data vector if verbose: print("Verbose mode enabled.") print("Generating data vector with config:", config) -def gen_covariance(config): +def gen_covariance(config: Dict[str, Any]) -> None: # Functionality for generating covariance print(not_implemented_message) -def forecast(config): +def forecast(config: Dict[str, Any]) -> None: # Functionality for forecast print(not_implemented_message) -def main(): +def main() -> None: parser = argparse.ArgumentParser( prog="CosmoAPI", description="CosmoAPI: Cosmology Analysis Pipeline Interface" @@ -83,4 +84,4 @@ def main(): forecast(config) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/CosmoAPI/api_io.py b/CosmoAPI/api_io.py index 4c50145..4669308 100644 --- a/CosmoAPI/api_io.py +++ b/CosmoAPI/api_io.py @@ -1,12 +1,13 @@ import yaml import importlib +from typing import Any, Dict -def load_yaml_file(file_path): +def load_yaml_file(file_path: str) -> Dict[str, Any]: """Helper function to load a YAML file""" with open(file_path, 'r') as file: return yaml.safe_load(file) -def load_metadata_function_class(function_name): +def load_metadata_function_class(function_name: str) -> Any: """ Dynamically load a class based on the 'function' name specified in the YAML file. FIXME: Change the docstrings diff --git a/CosmoAPI/two_point_functions/generate_theory.py b/CosmoAPI/two_point_functions/generate_theory.py index 77d43b5..ebc7f73 100644 --- a/CosmoAPI/two_point_functions/generate_theory.py +++ b/CosmoAPI/two_point_functions/generate_theory.py @@ -6,7 +6,7 @@ from firecrown.metadata_functions import make_all_photoz_bin_combinations import firecrown.likelihood.two_point as tp from firecrown.utils import base_model_from_yaml - +from typing import Dict, Any, List, Tuple from .nz_loader import load_all_nz sys.path.append("..") @@ -14,7 +14,7 @@ from api_io import load_metadata_function_class -def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float): +def generate_ell_theta_array_from_yaml(yaml_data: Dict[str, Any], type_key: str, dtype: type = float) -> np.ndarray: """ Generate a linear or logarithmic array based on the configuration in the YAML data. @@ -39,7 +39,7 @@ def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float): else: raise ValueError(f"Unknown array type: {array_type}") -def load_systematics_factory(probe_systematics): +def load_systematics_factory(probe_systematics: Dict[str, Any]) -> Any: """ Dynamically load a class based on the systematics 'type' specified in the YAML file. @@ -87,7 +87,7 @@ def load_systematics_factory(probe_systematics): except AttributeError as e: raise AttributeError(f"Class '{systematics_type}' not found in module {module_path}: {e}") -def process_probes_load_2pt(yaml_data): +def process_probes_load_2pt(yaml_data: Dict[str, Any]) -> Tuple[Any, List[str]]: """ Process the probes from the YAML data, check if 'function' is the same across probes with 'nz_type', @@ -138,8 +138,8 @@ def process_probes_load_2pt(yaml_data): return loaded_function, nz_type_probes -def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes, - two_point_bins): +def generate_two_point_metadata(yaml_data: Dict[str, Any], two_point_function: Any, two_pt_probes: List[str], + two_point_bins: List[Any]) -> List[Any]: """ Generate the metadata for the two-point functions based on the YAML data. @@ -181,7 +181,7 @@ def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes, raise ValueError("Unknown TwoPointFunction type") return all_two_point_metadata -def prepare_2pt_functions(yaml_data): +def prepare_2pt_functions(yaml_data: Dict[str, Any]) -> Tuple[Any, List[Any]]: # here we call this X because we do not know if it is ell_bins or theta_bins two_point_function, two_pt_probes = process_probes_load_2pt(yaml_data) diff --git a/CosmoAPI/two_point_functions/nz_loader.py b/CosmoAPI/two_point_functions/nz_loader.py index d445b8c..a197540 100644 --- a/CosmoAPI/two_point_functions/nz_loader.py +++ b/CosmoAPI/two_point_functions/nz_loader.py @@ -1,12 +1,13 @@ import importlib import sys +from typing import Dict, List, Any, Type sys.path.append("..") from not_implemented import not_implemented_message _DESC_SCENARIOS = {"LSST_Y10_SOURCE_BIN_COLLECTION", "LSST_Y10_LENS_BIN_COLLECTION", "LSST_Y1_LENS_BIN_COLLECTION", "LSST_Y1_SOURCE_BIN_COLLECTION",} -def _load_nz(yaml_data): +def _load_nz(yaml_data: Dict[str, Any]) -> List[Any]: try: nz_type = yaml_data["nz_type"] except KeyError: @@ -17,7 +18,7 @@ def _load_nz(yaml_data): else: raise NotImplementedError(not_implemented_message) -def load_all_nz(yaml_data): +def load_all_nz(yaml_data: Dict[str, Any]) -> List[Any]: nzs = [] for probe, propr in yaml_data['probes'].items(): if 'nz_type' in propr: @@ -25,7 +26,7 @@ def load_all_nz(yaml_data): nzs += _load_nz(propr) return nzs -def _load_nz_from_module(nz_type): +def _load_nz_from_module(nz_type: str) -> Type: # Define the module path module_path = "firecrown.generators.inferred_galaxy_zdist" @@ -37,4 +38,4 @@ def _load_nz_from_module(nz_type): except ImportError as e: raise ImportError(f"Failed to import module {module_path}: {e}") except AttributeError as e: - raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}") \ No newline at end of file + raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}")