diff --git a/dont_fret/config/config.py b/dont_fret/config/config.py index 93dc0ce..ede8566 100644 --- a/dont_fret/config/config.py +++ b/dont_fret/config/config.py @@ -48,6 +48,10 @@ class Web: burst_filters: list[BurstFilterItem] = field(default_factory=list) password: Optional[str] = None + # todo configurable settings + fret_2cde: bool = True # calculate fret_2cde after burst search with default settings + alex_2cde: bool = True + @dataclass class BurstColor: diff --git a/dont_fret/models.py b/dont_fret/models.py index b9f13cb..624614d 100644 --- a/dont_fret/models.py +++ b/dont_fret/models.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from dataclasses import dataclass from functools import cached_property, reduce from pathlib import Path from typing import TYPE_CHECKING, Optional, Union @@ -9,6 +10,7 @@ import polars as pl from dont_fret.burst_search import bs_eggeling, return_intersections +from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel from dont_fret.config.config import BurstColor, DontFRETConfig, cfg from dont_fret.support import get_binned from dont_fret.utils import clean_types @@ -267,14 +269,16 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: indices = pl.DataFrame({"imin": imin, "imax": imax}) # take all photons (up to and including? edges need to be checked!) b_num = int(2 ** np.ceil(np.log2((np.log2(len(imin)))))) - dtype = getattr(pl, f"UInt{b_num}", pl.Int32) + index_dtype = getattr(pl, f"UInt{b_num}", pl.Int32) bursts = [ - self.data[i1 : i2 + 1].with_columns(pl.lit(bi).alias("burst_index").cast(dtype)) + self.data[i1 : i2 + 1].with_columns( + pl.lit(bi).alias("burst_index").cast(index_dtype) + ) for bi, (i1, i2) in enumerate(zip(imin, imax)) ] burst_photons = pl.concat(bursts) - bs = Bursts(burst_photons, indices=indices, metadata=self.metadata) + bs = Bursts.from_photons(burst_photons, metadata=self.metadata) return bs @@ -356,27 +360,29 @@ def __len__(self) -> int: return len(self.time) -class Bursts(object): +@dataclass +class Bursts: """ Class which holds a set of bursts. attrs: - bursts np.array with Burst Objects + burst_data Dataframe with per-burst aggregated data + photon_data Dataframe with per-photon data """ - # todo add metadata support - # bursts: numpy.typing.ArrayLike[Bursts] ? - def __init__( - self, + burst_data: pl.DataFrame + photon_data: pl.DataFrame + metadata: Optional[dict] = None + cfg: DontFRETConfig = cfg + + @classmethod + def from_photons( + cls, photon_data: pl.DataFrame, - indices: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = cfg, - ): - self.photon_data = photon_data - self.indices = indices - self.metadata: dict = metadata or {} - self.cfg = cfg + ) -> Bursts: + # todo move to classmethod # number of photons per stream per burst agg = [(pl.col("stream") == stream).sum().alias(f"n_{stream}") for stream in cfg.streams] @@ -408,7 +414,7 @@ def __init__( # TODO These should move somewhere else; possibly some second step conversion # from raw stamps to times - t_unit = self.metadata.get("timestamps_unit", None) + t_unit = metadata.get("timestamps_unit", None) if t_unit is not None: columns.extend( [ @@ -418,7 +424,7 @@ def __init__( ), ] ) - nanotimes_unit = self.metadata.get("nanotimes_unit", None) + nanotimes_unit = metadata.get("nanotimes_unit", None) if nanotimes_unit is not None: columns.extend( [ @@ -427,28 +433,65 @@ def __init__( ] ) - self.burst_data = ( - self.photon_data.group_by("burst_index", maintain_order=True) - .agg(agg) - .with_columns(columns) + burst_data = ( + photon_data.group_by("burst_index", maintain_order=True).agg(agg).with_columns(columns) ) + return Bursts(burst_data, photon_data, metadata=metadata, cfg=cfg) + @classmethod def load(cls, directory: Path) -> Bursts: - data = pl.read_parquet(directory / "data.pq") + burst_data = pl.read_parquet(directory / "burst_data.pq") + photon_data = pl.read_parquet(directory / "photon_data.pq") with open(directory / "metadata.json", "r") as f: metadata = json.load(f) cfg = DontFRETConfig.from_yaml(directory / "config.yaml") - return Bursts(data, metadata, cfg) + return Bursts(burst_data, photon_data, metadata, cfg) def save(self, directory: Path) -> None: directory.mkdir(parents=True, exist_ok=True) - self.photon_data.write_parquet(directory / "data.pq") + self.burst_data.write_parquet(directory / "burst_data.pq") + self.photon_data.write_parquet(directory / "photon_data.pq") + with open(directory / "metadata.json", "w") as f: json.dump(self.metadata, f) self.cfg.to_yaml(directory / "config.yaml") + def fret_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts: + assert photons.timestamps_unit + kernel = make_kernel(tau, photons.timestamps_unit) + DA = convolve_stream(photons.data, ["DA"], kernel) + DD = convolve_stream(photons.data, ["DD"], kernel) + kde_data = photons.data.select( + [pl.col("timestamps"), pl.col("stream"), pl.lit(DA).alias("DA"), pl.lit(DD).alias("DD")] + ) + + fret_2cde = compute_fret_2cde(self.photon_data, kde_data) + burst_data = self.burst_data.with_columns(pl.lit(fret_2cde).alias("fret_2cde")) + + return Bursts(burst_data, self.photon_data, self.metadata, self.cfg) + + def alex_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts: + assert photons.timestamps_unit + kernel = make_kernel(tau, photons.timestamps_unit) + D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel) + A_ex = convolve_stream(photons.data, ["AA"], kernel) # crashed the kernel (sometimes) + + kde_data = photons.data.select( + [ + pl.col("timestamps"), + pl.col("stream"), + pl.lit(D_ex).alias("D_ex"), + pl.lit(A_ex).alias("A_ex"), + ] + ) + + alex_2cde = compute_alex_2cde(self.photon_data, kde_data) + burst_data = self.burst_data.with_columns(pl.lit(alex_2cde).alias("alex_2cde")) + + return Bursts(burst_data, self.photon_data, self.metadata, self.cfg) + def __len__(self) -> int: """Number of bursts""" return len(self.burst_data) @@ -459,6 +502,8 @@ def __iter__(self): @property def timestamps_unit(self) -> Optional[float]: """Multiplication factor to covert timestamps integers to seconds""" + if self.metadata is None: + return None try: return self.metadata["timestamps_unit"] except KeyError: diff --git a/dont_fret/process.py b/dont_fret/process.py index 421fe29..80e586f 100644 --- a/dont_fret/process.py +++ b/dont_fret/process.py @@ -8,8 +8,9 @@ from tqdm.auto import tqdm from dont_fret.config import cfg +from dont_fret.config.config import BurstColor, DontFRETConfig from dont_fret.fileIO import PhotonFile -from dont_fret.models import PhotonData +from dont_fret.models import Bursts, PhotonData def search_and_save( @@ -69,3 +70,16 @@ def batch_search_and_save( for f in tqdm(as_completed(futures), total=len(futures)): f.result() + + +def full_search( + photon_data: PhotonData, burst_colors: list[BurstColor], cfg: DontFRETConfig +) -> Bursts: + """search and alex / fret cde""" + bursts = photon_data.burst_search(burst_colors) + if cfg.web.alex_2cde: + bursts = bursts.alex_2cde(photon_data) + if cfg.web.fret_2cde: + bursts = bursts.fret_2cde(photon_data) + + return bursts diff --git a/dont_fret/web/datamanager.py b/dont_fret/web/datamanager.py index 34d44a1..85662de 100644 --- a/dont_fret/web/datamanager.py +++ b/dont_fret/web/datamanager.py @@ -11,22 +11,21 @@ Optional, ) -import numpy as np -import polars as pl - -from dont_fret.config.config import BurstColor +from dont_fret.config.config import BurstColor, DontFRETConfig, cfg from dont_fret.fileIO import PhotonFile from dont_fret.models import Bursts, PhotonData +from dont_fret.process import full_search from dont_fret.web.methods import get_duration, get_info, make_burst_dataframe from dont_fret.web.models import BurstNode, PhotonNode class ThreadedDataManager: - def __init__(self) -> None: + def __init__(self, cfg: DontFRETConfig = cfg) -> None: self.photon_cache: Dict[uuid.UUID, asyncio.Future[PhotonData]] = {} self.burst_cache: Dict[tuple[uuid.UUID, str], asyncio.Future[Bursts]] = {} self.executor = ThreadPoolExecutor(max_workers=4) # todo config wokers self.running_jobs = {} + self.cfg = cfg # todo allow passing loop to init @property @@ -78,7 +77,9 @@ async def get_bursts( self.burst_cache[key] = future try: - bursts = await self.search(photon_node, burst_colors) + photon_data = await self.get_photons(photon_node) + bursts = await self.run(full_search, photon_data, burst_colors, self.cfg) + future.set_result(bursts) except Exception as e: self.burst_cache.pop(key) @@ -115,29 +116,15 @@ async def get_bursts_batch( return results - async def get_dataframe( - self, - photon_nodes: list[PhotonNode], - burst_colors: list[BurstColor], - on_progress: Optional[Callable[[float | bool], None]] = None, - ) -> pl.DataFrame: - on_progress = on_progress or (lambda _: None) - on_progress(True) - raise DeprecationWarning("USe get burst node instead") - results = await self.get_bursts_batch(photon_nodes, burst_colors, on_progress) - on_progress(True) - - names = [ph_node.name for ph_node in photon_nodes] - lens = [len(burst) for burst in results] - - dtype = pl.Enum(categories=names) - filenames = pl.Series(name="filename", values=np.repeat(names, lens), dtype=dtype) - - df = pl.concat([b.burst_data for b in results], how="vertical_relaxed").with_columns( - filenames - ) + async def alex_2cde(self, photon_node: PhotonNode, bursts: Bursts) -> Bursts: + photons = await self.get_photons(photon_node) + new_bursts = self.run(bursts.alex_2cde, photons) + return await new_bursts - return df + async def fret_2cde(self, photon_node: PhotonNode, bursts: Bursts) -> Bursts: + photons = await self.get_photons(photon_node) + new_bursts = self.run(bursts.fret_2cde, photons) + return await new_bursts async def get_burst_node( self, @@ -148,7 +135,6 @@ async def get_burst_node( ) -> BurstNode: bursts = await self.get_bursts_batch(photon_nodes, burst_colors, on_progress=on_progress) bursts_df = make_burst_dataframe(bursts, names=[ph_node.name for ph_node in photon_nodes]) - # burst_df = await self.get_dataframe(photon_nodes, burst_colors, on_progress=on_progress) info_list = [await self.get_info(node) for node in photon_nodes] duration = get_duration(info_list) diff --git a/dont_fret/web/methods.py b/dont_fret/web/methods.py index 86bd229..50315a3 100644 --- a/dont_fret/web/methods.py +++ b/dont_fret/web/methods.py @@ -45,14 +45,23 @@ def make_burst_dataframe( return concat +# hooks? def make_burst_nodes( - photon_nodes: list[PhotonNode], burst_settings: dict[str, list[BurstColor]] + photon_nodes: list[PhotonNode], + burst_settings: dict[str, list[BurstColor]], + alex_2cde=True, + fret_2cde=True, ) -> list[BurstNode]: photons = [PhotonData.from_file(PhotonFile(node.file_path)) for node in photon_nodes] burst_nodes = [] # todo tqdm? for name, burst_colors in burst_settings.items(): bursts = [photons.burst_search(burst_colors) for photons in photons] + if alex_2cde: + bursts = [b.alex_2cde(photons) for b, photons in zip(bursts, photons)] + if fret_2cde: + bursts = [b.fret_2cde(photons) for b, photons in zip(bursts, photons)] + infos = [get_info(photons) for photons in photons] duration = get_duration(infos) df = make_burst_dataframe(bursts, names=[node.name for node in photon_nodes])