diff --git a/dont_fret/models.py b/dont_fret/models.py index b56299b..170b824 100644 --- a/dont_fret/models.py +++ b/dont_fret/models.py @@ -11,7 +11,8 @@ from dont_fret.burst_search import bs_eggeling, return_intersections from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel -from dont_fret.config.config import BurstColor, DontFRETConfig, cfg +from dont_fret.config import cfg as global_cfg +from dont_fret.config.config import BurstColor, DontFRETConfig from dont_fret.support import get_binned from dont_fret.utils import clean_types @@ -34,7 +35,7 @@ class PhotonData: """ def __init__( - self, data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = cfg + self, data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = global_cfg ): self.data = data self.metadata = metadata or {} @@ -224,7 +225,7 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: """ if isinstance(colors, str): - burst_colors = cfg.burst_search[colors] + burst_colors = global_cfg.burst_search[colors] elif isinstance(colors, list): burst_colors = colors else: @@ -380,7 +381,7 @@ def from_photons( cls, photon_data: pl.DataFrame, metadata: Optional[dict] = None, - cfg: DontFRETConfig = cfg, + cfg: DontFRETConfig = global_cfg, ) -> Bursts: # todo move to classmethod @@ -446,7 +447,10 @@ def load(cls, directory: Path) -> Bursts: with open(directory / "metadata.json", "r") as f: metadata = json.load(f) - cfg = DontFRETConfig.from_yaml(directory / "config.yaml") + try: + cfg = DontFRETConfig.from_yaml(directory / "config.yaml") + except FileNotFoundError: + cfg = None return Bursts(burst_data, photon_data, metadata, cfg) def save(self, directory: Path) -> None: @@ -456,7 +460,8 @@ def save(self, directory: Path) -> None: with open(directory / "metadata.json", "w") as f: json.dump(self.metadata, f) - self.cfg.to_yaml(directory / "config.yaml") + if self.cfg is not None: + self.cfg.to_yaml(directory / "config.yaml") def fret_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts: assert photons.timestamps_unit