Skip to content

Commit

Permalink
Implement typing
Browse files Browse the repository at this point in the history
Fixes #12

Add type hinting to various functions across multiple files for consistency with Firecrown.

* **`CosmoAPI/two_point_functions/generate_theory.py`**
  - Add type hints to `generate_ell_theta_array_from_yaml`, `load_systematics_factory`, `process_probes_load_2pt`, `generate_two_point_metadata`, and `prepare_2pt_functions`.

* **`CosmoAPI/two_point_functions/nz_loader.py`**
  - Add type hints to `_load_nz`, `load_all_nz`, and `_load_nz_from_module`.

* **`CosmoAPI/__main__.py`**
  - Add type hints to `gen_datavec`, `gen_covariance`, `forecast`, and `main`.

* **`CosmoAPI/api_io.py`**
  - Add type hints to `load_yaml_file` and `load_metadata_function_class`.

Important: Change the placeholder `Any` for the actual firecrown types.
  • Loading branch information
arthurmloureiro committed Oct 25, 2024
1 parent 9620304 commit 774b601
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 18 deletions.
11 changes: 6 additions & 5 deletions CosmoAPI/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import argparse
from typing import Dict, Any

from .api_io import load_yaml_file
from .not_implemented import not_implemented_message

def gen_datavec(config, verbose=False):
def gen_datavec(config: Dict[str, Any], verbose: bool = False) -> None:
# Functionality for generating data vector
if verbose:
print("Verbose mode enabled.")
print("Generating data vector with config:", config)

def gen_covariance(config):
def gen_covariance(config: Dict[str, Any]) -> None:
# Functionality for generating covariance
print(not_implemented_message)

def forecast(config):
def forecast(config: Dict[str, Any]) -> None:
# Functionality for forecast
print(not_implemented_message)

def main():
def main() -> None:
parser = argparse.ArgumentParser(
prog="CosmoAPI",
description="CosmoAPI: Cosmology Analysis Pipeline Interface"
Expand Down Expand Up @@ -83,4 +84,4 @@ def main():
forecast(config)

if __name__ == "__main__":
main()
main()
5 changes: 3 additions & 2 deletions CosmoAPI/api_io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import yaml
import importlib
from typing import Any, Dict

def load_yaml_file(file_path):
def load_yaml_file(file_path: str) -> Dict[str, Any]:
"""Helper function to load a YAML file"""
with open(file_path, 'r') as file:
return yaml.safe_load(file)

def load_metadata_function_class(function_name):
def load_metadata_function_class(function_name: str) -> Any:
"""
Dynamically load a class based on the 'function' name specified in the YAML file.
FIXME: Change the docstrings
Expand Down
14 changes: 7 additions & 7 deletions CosmoAPI/two_point_functions/generate_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from firecrown.metadata_functions import make_all_photoz_bin_combinations
import firecrown.likelihood.two_point as tp
from firecrown.utils import base_model_from_yaml

from typing import Dict, Any, List, Tuple

from .nz_loader import load_all_nz
sys.path.append("..")
from not_implemented import not_implemented_message
from api_io import load_metadata_function_class


def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float):
def generate_ell_theta_array_from_yaml(yaml_data: Dict[str, Any], type_key: str, dtype: type = float) -> np.ndarray:
"""
Generate a linear or logarithmic array based on the configuration in the YAML data.
Expand All @@ -39,7 +39,7 @@ def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float):
else:
raise ValueError(f"Unknown array type: {array_type}")

def load_systematics_factory(probe_systematics):
def load_systematics_factory(probe_systematics: Dict[str, Any]) -> Any:
"""
Dynamically load a class based on the systematics 'type' specified in the YAML file.
Expand Down Expand Up @@ -87,7 +87,7 @@ def load_systematics_factory(probe_systematics):
except AttributeError as e:
raise AttributeError(f"Class '{systematics_type}' not found in module {module_path}: {e}")

def process_probes_load_2pt(yaml_data):
def process_probes_load_2pt(yaml_data: Dict[str, Any]) -> Tuple[Any, List[str]]:
"""
Process the probes from the YAML data, check if 'function'
is the same across probes with 'nz_type',
Expand Down Expand Up @@ -138,8 +138,8 @@ def process_probes_load_2pt(yaml_data):

return loaded_function, nz_type_probes

def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes,
two_point_bins):
def generate_two_point_metadata(yaml_data: Dict[str, Any], two_point_function: Any, two_pt_probes: List[str],
two_point_bins: List[Any]) -> List[Any]:
"""
Generate the metadata for the two-point functions based on the YAML data.
Expand Down Expand Up @@ -181,7 +181,7 @@ def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes,
raise ValueError("Unknown TwoPointFunction type")
return all_two_point_metadata

def prepare_2pt_functions(yaml_data):
def prepare_2pt_functions(yaml_data: Dict[str, Any]) -> Tuple[Any, List[Any]]:
# here we call this X because we do not know if it is ell_bins or theta_bins
two_point_function, two_pt_probes = process_probes_load_2pt(yaml_data)

Expand Down
9 changes: 5 additions & 4 deletions CosmoAPI/two_point_functions/nz_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import importlib
import sys
from typing import Dict, List, Any, Type
sys.path.append("..")
from not_implemented import not_implemented_message

_DESC_SCENARIOS = {"LSST_Y10_SOURCE_BIN_COLLECTION", "LSST_Y10_LENS_BIN_COLLECTION",
"LSST_Y1_LENS_BIN_COLLECTION", "LSST_Y1_SOURCE_BIN_COLLECTION",}

def _load_nz(yaml_data):
def _load_nz(yaml_data: Dict[str, Any]) -> List[Any]:
try:
nz_type = yaml_data["nz_type"]
except KeyError:
Expand All @@ -17,15 +18,15 @@ def _load_nz(yaml_data):
else:
raise NotImplementedError(not_implemented_message)

def load_all_nz(yaml_data):
def load_all_nz(yaml_data: Dict[str, Any]) -> List[Any]:
nzs = []
for probe, propr in yaml_data['probes'].items():
if 'nz_type' in propr:
#print(propr['nz_type'])
nzs += _load_nz(propr)
return nzs

def _load_nz_from_module(nz_type):
def _load_nz_from_module(nz_type: str) -> Type:
# Define the module path
module_path = "firecrown.generators.inferred_galaxy_zdist"

Expand All @@ -37,4 +38,4 @@ def _load_nz_from_module(nz_type):
except ImportError as e:
raise ImportError(f"Failed to import module {module_path}: {e}")
except AttributeError as e:
raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}")
raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}")

0 comments on commit 774b601

Please sign in to comment.