Skip to content

Commit

Permalink
fix namespace clash in dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Nov 22, 2024
1 parent 63c9136 commit 155e9c3
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions dont_fret/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from dont_fret.burst_search import bs_eggeling, return_intersections
from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel
from dont_fret.config.config import BurstColor, DontFRETConfig, cfg
from dont_fret.config import cfg as global_cfg
from dont_fret.config.config import BurstColor, DontFRETConfig
from dont_fret.support import get_binned
from dont_fret.utils import clean_types

Expand All @@ -34,7 +35,7 @@ class PhotonData:
"""

def __init__(
self, data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = cfg
self, data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = global_cfg
):
self.data = data
self.metadata = metadata or {}
Expand Down Expand Up @@ -224,7 +225,7 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts:
"""

if isinstance(colors, str):
burst_colors = cfg.burst_search[colors]
burst_colors = global_cfg.burst_search[colors]
elif isinstance(colors, list):
burst_colors = colors
else:
Expand Down Expand Up @@ -380,7 +381,7 @@ def from_photons(
cls,
photon_data: pl.DataFrame,
metadata: Optional[dict] = None,
cfg: DontFRETConfig = cfg,
cfg: DontFRETConfig = global_cfg,
) -> Bursts:
# todo move to classmethod

Expand Down Expand Up @@ -446,7 +447,10 @@ def load(cls, directory: Path) -> Bursts:
with open(directory / "metadata.json", "r") as f:
metadata = json.load(f)

cfg = DontFRETConfig.from_yaml(directory / "config.yaml")
try:
cfg = DontFRETConfig.from_yaml(directory / "config.yaml")
except FileNotFoundError:
cfg = None
return Bursts(burst_data, photon_data, metadata, cfg)

def save(self, directory: Path) -> None:
Expand All @@ -456,7 +460,8 @@ def save(self, directory: Path) -> None:

with open(directory / "metadata.json", "w") as f:
json.dump(self.metadata, f)
self.cfg.to_yaml(directory / "config.yaml")
if self.cfg is not None:
self.cfg.to_yaml(directory / "config.yaml")

def fret_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts:
assert photons.timestamps_unit
Expand Down

0 comments on commit 155e9c3

Please sign in to comment.