Skip to content

Commit

Permalink
allow front end to do alex/fret cde
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Nov 22, 2024
1 parent 058e1dd commit e6fb85a
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 55 deletions.
4 changes: 4 additions & 0 deletions dont_fret/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class Web:
burst_filters: list[BurstFilterItem] = field(default_factory=list)
password: Optional[str] = None

# todo configurable settings
fret_2cde: bool = True # calculate fret_2cde after burst search with default settings
alex_2cde: bool = True


@dataclass
class BurstColor:
Expand Down
93 changes: 69 additions & 24 deletions dont_fret/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
from dataclasses import dataclass
from functools import cached_property, reduce
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
Expand All @@ -9,6 +10,7 @@
import polars as pl

from dont_fret.burst_search import bs_eggeling, return_intersections
from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel
from dont_fret.config.config import BurstColor, DontFRETConfig, cfg
from dont_fret.support import get_binned
from dont_fret.utils import clean_types
Expand Down Expand Up @@ -267,14 +269,16 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts:
indices = pl.DataFrame({"imin": imin, "imax": imax})
# take all photons (up to and including? edges need to be checked!)
b_num = int(2 ** np.ceil(np.log2((np.log2(len(imin))))))
dtype = getattr(pl, f"UInt{b_num}", pl.Int32)
index_dtype = getattr(pl, f"UInt{b_num}", pl.Int32)
bursts = [
self.data[i1 : i2 + 1].with_columns(pl.lit(bi).alias("burst_index").cast(dtype))
self.data[i1 : i2 + 1].with_columns(
pl.lit(bi).alias("burst_index").cast(index_dtype)
)
for bi, (i1, i2) in enumerate(zip(imin, imax))
]
burst_photons = pl.concat(bursts)

bs = Bursts(burst_photons, indices=indices, metadata=self.metadata)
bs = Bursts.from_photons(burst_photons, metadata=self.metadata)

return bs

Expand Down Expand Up @@ -356,27 +360,29 @@ def __len__(self) -> int:
return len(self.time)


class Bursts(object):
@dataclass
class Bursts:
"""
Class which holds a set of bursts.
attrs:
bursts np.array with Burst Objects
burst_data Dataframe with per-burst aggregated data
photon_data Dataframe with per-photon data
"""

# todo add metadata support
# bursts: numpy.typing.ArrayLike[Bursts] ?
def __init__(
self,
burst_data: pl.DataFrame
photon_data: pl.DataFrame
metadata: Optional[dict] = None
cfg: DontFRETConfig = cfg

@classmethod
def from_photons(
cls,
photon_data: pl.DataFrame,
indices: pl.DataFrame,
metadata: Optional[dict] = None,
cfg: DontFRETConfig = cfg,
):
self.photon_data = photon_data
self.indices = indices
self.metadata: dict = metadata or {}
self.cfg = cfg
) -> Bursts:
# todo move to classmethod

# number of photons per stream per burst
agg = [(pl.col("stream") == stream).sum().alias(f"n_{stream}") for stream in cfg.streams]
Expand Down Expand Up @@ -408,7 +414,7 @@ def __init__(

# TODO These should move somewhere else; possibly some second step conversion
# from raw stamps to times
t_unit = self.metadata.get("timestamps_unit", None)
t_unit = metadata.get("timestamps_unit", None)
if t_unit is not None:
columns.extend(
[
Expand All @@ -418,7 +424,7 @@ def __init__(
),
]
)
nanotimes_unit = self.metadata.get("nanotimes_unit", None)
nanotimes_unit = metadata.get("nanotimes_unit", None)
if nanotimes_unit is not None:
columns.extend(
[
Expand All @@ -427,28 +433,65 @@ def __init__(
]
)

self.burst_data = (
self.photon_data.group_by("burst_index", maintain_order=True)
.agg(agg)
.with_columns(columns)
burst_data = (
photon_data.group_by("burst_index", maintain_order=True).agg(agg).with_columns(columns)
)

return Bursts(burst_data, photon_data, metadata=metadata, cfg=cfg)

@classmethod
def load(cls, directory: Path) -> Bursts:
    """Load a `Bursts` object from a directory written by `save`.

    Reads ``burst_data.pq`` and ``photon_data.pq`` (parquet),
    ``metadata.json`` and ``config.yaml`` from `directory`.

    Args:
        directory: Directory containing the four files listed above.

    Returns:
        A `Bursts` instance built from the stored tables, metadata and config.
    """
    # Per-burst and per-photon tables live in separate parquet files; the
    # single legacy ``data.pq`` file is not read.
    burst_data = pl.read_parquet(directory / "burst_data.pq")
    photon_data = pl.read_parquet(directory / "photon_data.pq")
    with open(directory / "metadata.json", "r") as f:
        metadata = json.load(f)

    cfg = DontFRETConfig.from_yaml(directory / "config.yaml")
    return Bursts(burst_data, photon_data, metadata, cfg)

def save(self, directory: Path) -> None:
    """Write this object to `directory`, creating the directory if needed.

    Produces ``burst_data.pq``, ``photon_data.pq``, ``metadata.json`` and
    ``config.yaml``; these are exactly the files `load` reads back.

    Args:
        directory: Target directory; missing parents are created.
    """
    directory.mkdir(parents=True, exist_ok=True)
    # Each table is written exactly once, under the name `load` expects.
    self.burst_data.write_parquet(directory / "burst_data.pq")
    self.photon_data.write_parquet(directory / "photon_data.pq")

    with open(directory / "metadata.json", "w") as f:
        json.dump(self.metadata, f)
    self.cfg.to_yaml(directory / "config.yaml")

def fret_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts:
    """Return a new `Bursts` whose burst table has an added ``fret_2cde`` column.

    The "DA" and "DD" photon streams of `photons` are each convolved with a
    kernel built from `tau` and the photon timestamp unit; the result is
    handed to `compute_fret_2cde` together with this object's photon data.

    Args:
        photons: Photon data the bursts were found in; must have a truthy
            `timestamps_unit`.
        tau: Kernel parameter forwarded to `make_kernel`
            (presumably a time constant in seconds; confirm against `make_kernel`).

    Returns:
        A `Bursts` sharing this object's photon data, metadata and config,
        with ``fret_2cde`` added to `burst_data`.

    Raises:
        AssertionError: If `photons.timestamps_unit` is falsy.
    """
    assert photons.timestamps_unit
    smoothing_kernel = make_kernel(tau, photons.timestamps_unit)

    # Convolved streams, keyed by the column name they get in the KDE table.
    kde_columns = {
        "DA": convolve_stream(photons.data, ["DA"], smoothing_kernel),
        "DD": convolve_stream(photons.data, ["DD"], smoothing_kernel),
    }
    kde_data = photons.data.select(
        [pl.col("timestamps"), pl.col("stream")]
        + [pl.lit(values).alias(name) for name, values in kde_columns.items()]
    )

    per_burst_values = compute_fret_2cde(self.photon_data, kde_data)
    updated_bursts = self.burst_data.with_columns(
        pl.lit(per_burst_values).alias("fret_2cde")
    )

    return Bursts(updated_bursts, self.photon_data, self.metadata, self.cfg)

def alex_2cde(self, photons: PhotonData, tau: float = 50e-6) -> Bursts:
    """Return a new `Bursts` whose burst table has an added ``alex_2cde`` column.

    Two convolved streams are built from `photons`: ``D_ex`` from the "DD"
    and "DA" streams together, and ``A_ex`` from the "AA" stream. They are
    handed to `compute_alex_2cde` together with this object's photon data.

    Args:
        photons: Photon data the bursts were found in; must have a truthy
            `timestamps_unit`.
        tau: Kernel parameter forwarded to `make_kernel`
            (presumably a time constant in seconds; confirm against `make_kernel`).

    Returns:
        A `Bursts` sharing this object's photon data, metadata and config,
        with ``alex_2cde`` added to `burst_data`.

    Raises:
        AssertionError: If `photons.timestamps_unit` is falsy.
    """
    assert photons.timestamps_unit
    smoothing_kernel = make_kernel(tau, photons.timestamps_unit)

    donor_excitation = convolve_stream(photons.data, ["DD", "DA"], smoothing_kernel)
    # NOTE: this convolution has been seen to crash the kernel (sometimes).
    acceptor_excitation = convolve_stream(photons.data, ["AA"], smoothing_kernel)

    kde_exprs = [
        pl.col("timestamps"),
        pl.col("stream"),
        pl.lit(donor_excitation).alias("D_ex"),
        pl.lit(acceptor_excitation).alias("A_ex"),
    ]
    kde_data = photons.data.select(kde_exprs)

    per_burst_values = compute_alex_2cde(self.photon_data, kde_data)
    updated_bursts = self.burst_data.with_columns(
        pl.lit(per_burst_values).alias("alex_2cde")
    )

    return Bursts(updated_bursts, self.photon_data, self.metadata, self.cfg)

def __len__(self) -> int:
"""Number of bursts"""
return len(self.burst_data)
Expand All @@ -459,6 +502,8 @@ def __iter__(self):
@property
def timestamps_unit(self) -> Optional[float]:
"""Multiplication factor to covert timestamps integers to seconds"""
if self.metadata is None:
return None
try:
return self.metadata["timestamps_unit"]
except KeyError:
Expand Down
16 changes: 15 additions & 1 deletion dont_fret/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from tqdm.auto import tqdm

from dont_fret.config import cfg
from dont_fret.config.config import BurstColor, DontFRETConfig
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData
from dont_fret.models import Bursts, PhotonData


def search_and_save(
Expand Down Expand Up @@ -69,3 +70,16 @@ def batch_search_and_save(

for f in tqdm(as_completed(futures), total=len(futures)):
f.result()


def full_search(
    photon_data: PhotonData, burst_colors: list[BurstColor], cfg: DontFRETConfig
) -> Bursts:
    """Run a burst search, then the optional ALEX / FRET 2CDE steps.

    Args:
        photon_data: Photon data to search; also passed to the 2CDE steps.
        burst_colors: Burst search settings forwarded to `burst_search`.
        cfg: Configuration; `cfg.web.alex_2cde` and `cfg.web.fret_2cde`
            switch the respective step on or off.

    Returns:
        The bursts found, after whichever 2CDE steps are enabled
        (ALEX first, then FRET).
    """
    web_settings = cfg.web
    found = photon_data.burst_search(burst_colors)
    after_alex = found.alex_2cde(photon_data) if web_settings.alex_2cde else found
    return after_alex.fret_2cde(photon_data) if web_settings.fret_2cde else after_alex
44 changes: 15 additions & 29 deletions dont_fret/web/datamanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,21 @@
Optional,
)

import numpy as np
import polars as pl

from dont_fret.config.config import BurstColor
from dont_fret.config.config import BurstColor, DontFRETConfig, cfg
from dont_fret.fileIO import PhotonFile
from dont_fret.models import Bursts, PhotonData
from dont_fret.process import full_search
from dont_fret.web.methods import get_duration, get_info, make_burst_dataframe
from dont_fret.web.models import BurstNode, PhotonNode


class ThreadedDataManager:
def __init__(self, cfg: DontFRETConfig = cfg) -> None:
    """Create a data manager with empty caches and its own thread pool.

    Args:
        cfg: Configuration stored on the manager as `self.cfg`; defaults to
            the module-level `cfg`.
    """
    # Caches hold futures, keyed by a UUID (photons) or a (UUID, str) pair (bursts).
    self.photon_cache: Dict[uuid.UUID, asyncio.Future[PhotonData]] = {}
    self.burst_cache: Dict[tuple[uuid.UUID, str], asyncio.Future[Bursts]] = {}
    self.executor = ThreadPoolExecutor(max_workers=4)  # todo config workers
    self.running_jobs = {}
    self.cfg = cfg

# todo allow passing loop to init
@property
Expand Down Expand Up @@ -78,7 +77,9 @@ async def get_bursts(
self.burst_cache[key] = future

try:
bursts = await self.search(photon_node, burst_colors)
photon_data = await self.get_photons(photon_node)
bursts = await self.run(full_search, photon_data, burst_colors, self.cfg)

future.set_result(bursts)
except Exception as e:
self.burst_cache.pop(key)
Expand Down Expand Up @@ -115,29 +116,15 @@ async def get_bursts_batch(

return results

async def get_dataframe(
    self,
    photon_nodes: list[PhotonNode],
    burst_colors: list[BurstColor],
    on_progress: Optional[Callable[[float | bool], None]] = None,
) -> pl.DataFrame:
    """Deprecated entry point; always raises. Use `get_burst_node` instead.

    Args:
        photon_nodes: Unused; kept so existing call signatures still match.
        burst_colors: Unused; kept so existing call signatures still match.
        on_progress: Optional progress callback; called once with ``True``
            before the exception is raised.

    Raises:
        DeprecationWarning: Always.
    """
    on_progress = on_progress or (lambda _: None)
    on_progress(True)
    # Nothing after this point can run, so no dataframe is ever built here.
    raise DeprecationWarning("USe get burst node instead")

async def alex_2cde(self, photon_node: PhotonNode, bursts: Bursts) -> Bursts:
    """Compute ALEX 2CDE for `bursts` via `self.run`.

    Args:
        photon_node: Node whose photon data is fetched with `get_photons`.
        bursts: Bursts whose `alex_2cde` method is run on that photon data.

    Returns:
        The awaited result of `bursts.alex_2cde(photons)`.
    """
    photons = await self.get_photons(photon_node)
    return await self.run(bursts.alex_2cde, photons)

async def fret_2cde(self, photon_node: PhotonNode, bursts: Bursts) -> Bursts:
    """Compute FRET 2CDE for `bursts` via `self.run`.

    Args:
        photon_node: Node whose photon data is fetched with `get_photons`.
        bursts: Bursts whose `fret_2cde` method is run on that photon data.

    Returns:
        The awaited result of `bursts.fret_2cde(photons)`.
    """
    photon_data = await self.get_photons(photon_node)
    pending = self.run(bursts.fret_2cde, photon_data)
    result = await pending
    return result

async def get_burst_node(
self,
Expand All @@ -148,7 +135,6 @@ async def get_burst_node(
) -> BurstNode:
bursts = await self.get_bursts_batch(photon_nodes, burst_colors, on_progress=on_progress)
bursts_df = make_burst_dataframe(bursts, names=[ph_node.name for ph_node in photon_nodes])
# burst_df = await self.get_dataframe(photon_nodes, burst_colors, on_progress=on_progress)
info_list = [await self.get_info(node) for node in photon_nodes]

duration = get_duration(info_list)
Expand Down
11 changes: 10 additions & 1 deletion dont_fret/web/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,23 @@ def make_burst_dataframe(
return concat


# hooks?
def make_burst_nodes(
photon_nodes: list[PhotonNode], burst_settings: dict[str, list[BurstColor]]
photon_nodes: list[PhotonNode],
burst_settings: dict[str, list[BurstColor]],
alex_2cde=True,
fret_2cde=True,
) -> list[BurstNode]:
photons = [PhotonData.from_file(PhotonFile(node.file_path)) for node in photon_nodes]
burst_nodes = []
# todo tqdm?
for name, burst_colors in burst_settings.items():
bursts = [photons.burst_search(burst_colors) for photons in photons]
if alex_2cde:
bursts = [b.alex_2cde(photons) for b, photons in zip(bursts, photons)]
if fret_2cde:
bursts = [b.fret_2cde(photons) for b, photons in zip(bursts, photons)]

infos = [get_info(photons) for photons in photons]
duration = get_duration(infos)
df = make_burst_dataframe(bursts, names=[node.name for node in photon_nodes])
Expand Down

0 comments on commit e6fb85a

Please sign in to comment.