diff --git a/dont_fret/channel_kde.py b/dont_fret/channel_kde.py
new file mode 100644
index 0000000..fac0a3a
--- /dev/null
+++ b/dont_fret/channel_kde.py
@@ -0,0 +1,142 @@
+"""
+Module for channel-based kernel density estimation.
+
+This is based on:
+Disentangling Subpopulations in Single-Molecule FRET and ALEX Experiments with Photon Distribution Analysis
+https://doi.org/10.1016/j.bpj.2011.11.4025
+
+Please cite the paper when using this module.
+
+"""
+
+import warnings
+from typing import Optional
+
+import numpy as np
+import polars as pl
+from numba import float64, int64, jit, types
+
+from dont_fret.expr import is_in_expr
+
+
+def alex_2cde(
+    burst_photons: pl.DataFrame,
+    kde_rates: pl.DataFrame,
+    dex_streams: Optional[list[str]] = None,
+    aex_streams: Optional[list[str]] = None,
+) -> pl.Series:
+    """
+    burst_photons: DataFrame with columns: timestamps, stream, burst_index.
+    kde_rates: DataFrame with columns: timestamps, stream, D_ex, A_ex.
+    dex_streams: list of photon streams which are donor excitation (default: DD, DA).
+    aex_streams: list of photon streams which are acceptor excitation (default: AA).
+    """
+    dex_streams = dex_streams if dex_streams else ["DD", "DA"]
+    aex_streams = aex_streams if aex_streams else ["AA"]
+
+    f_dex = is_in_expr("stream", dex_streams)
+    f_aex = is_in_expr("stream", aex_streams)
+
+    # equivalent to (but faster than):
+    # joined_df = burst_photons.join(kde_rates, on=['timestamps', 'stream'], how='inner')
+    j1 = burst_photons.join(kde_rates, on=["timestamps"], how="left")
+    joined_df = j1.filter(pl.col("stream") == pl.col("stream_right"))
+
+    b_df = joined_df.select(
+        [
+            pl.col("burst_index"),
+            pl.col("stream"),
+            (pl.col("A_ex") / pl.col("D_ex")).alias("ratio_AD"),
+        ]
+    )
+
+    # Tomov et al. eqns 10 and 11
+    df_f_dex = b_df.filter(f_dex)
+    agg_dex = df_f_dex.group_by("burst_index", maintain_order=True).agg(
+        [pl.col("ratio_AD").sum(), pl.len().alias("N_dex")]
+    )
+
+    df_f_aex = b_df.filter(f_aex)
+    agg_aex = df_f_aex.group_by("burst_index", maintain_order=True).agg(
+        [(1 / pl.col("ratio_AD")).sum().alias("ratio_DA"), pl.len().alias("N_aex")]
+    )
+
+    combined = pl.concat([agg_dex, agg_aex], how="align")
+
+    # Tomov et al. eqn 12
+    ax_2cde_bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") - (1 / pl.col("N_dex")) * pl.col(
+        "ratio_DA"
+    )
+    ax_2cde_norm = pl.lit(100) - pl.lit(50) * ax_2cde_bracket
+
+    alex_2cde = combined.select(ax_2cde_norm.alias("alex_2cde")).to_series()
+
+    return alex_2cde
+
+
+def make_kernel(
+    tau: float, timestamps_unit: float, domain_size: int = 10, kernel: str = "laplace"
+) -> np.ndarray:
+    # note: only the Laplace (two-sided exponential) kernel is implemented;
+    # the `kernel` argument is currently unused
+    window_size = domain_size * (tau / timestamps_unit)
+    window_size_even_int = 2 * round(window_size / 2)
+
+    # check that the rounding error isn't too large
+    rel_dev = (window_size - window_size_even_int) / window_size
+    if np.abs(rel_dev) > 0.01:
+        warnings.warn(
+            "Kernel window size deviation from rounding larger than 1 percent. "
+            "Choose a smaller `tau` with respect to `timestamps_unit`."
+        )
+
+    t_eval = np.linspace(-domain_size / 2, domain_size / 2, window_size_even_int + 1, endpoint=True)
+    kernel = np.exp(-np.abs(t_eval))
+
+    return kernel
+
+
+def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray) -> np.ndarray:
+    f_expr = is_in_expr("stream", streams)
+
+    df = data.filter(f_expr)
+    # TODO warn on copy
+    event_times = df["timestamps"].to_numpy(allow_copy=True)
+    eval_times = data["timestamps"].to_numpy(allow_copy=True)
+
+    return async_convolve(event_times, eval_times, kernel)
+
+
+@jit(
+    float64[:](
+        types.Array(int64, 1, "C", readonly=True),
+        types.Array(int64, 1, "C", readonly=True),
+        types.Array(float64, 1, "C", readonly=True),
+    ),
+    nopython=True,
+    cache=True,
+    nogil=True,
+)
+def async_convolve(event_times, eval_times, kernel):
+    """Convolve integer timestamps with a kernel."""
+
+    i_lower = 0
+    i_upper = 0
+
+    window_half_size = len(kernel) // 2
+    result = np.zeros_like(eval_times, dtype=np.float64)
+
+    for i, time in enumerate(eval_times):
+        # advance the window edges; bounds-check before indexing
+        while i_upper < len(event_times) and event_times[i_upper] < time + window_half_size:
+            i_upper += 1
+
+        while i_lower < len(event_times) and event_times[i_lower] < time - window_half_size:
+            i_lower += 1
+
+        for j in range(i_lower, i_upper):
+            idx = event_times[j] - time + window_half_size
+            result[i] += kernel[idx]
+
+    return result
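# A minimal usage sketch for alex_2cde with toy numbers (not taken from any dataset in this
# diff). It only illustrates the expected input columns: burst_photons needs
# timestamps / stream / burst_index, kde_rates needs timestamps / stream / D_ex / A_ex,
# where D_ex and A_ex are the kernel-density rates evaluated at every photon
# (e.g. the output of convolve_stream).
import polars as pl

from dont_fret.channel_kde import alex_2cde

kde_rates = pl.DataFrame(
    {
        "timestamps": [10, 20, 30, 40],
        "stream": ["DD", "DA", "AA", "DD"],
        "D_ex": [4.0, 3.5, 1.0, 4.2],
        "A_ex": [1.0, 1.2, 3.0, 0.9],
    }
)
burst_photons = pl.DataFrame(
    {
        "timestamps": [10, 20, 30, 40],
        "stream": ["DD", "DA", "AA", "DD"],
        "burst_index": [0, 0, 0, 0],
    }
)

values = alex_2cde(burst_photons, kde_rates)  # one ALEX-2CDE value per burst
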
diff --git a/dont_fret/expr.py b/dont_fret/expr.py
index 123dc07..a175424 100644
--- a/dont_fret/expr.py
+++ b/dont_fret/expr.py
@@ -32,6 +32,17 @@ def reduce_or(exprs: list[pl.Expr]) -> pl.Expr:
     return reduce(or_, exprs)
 
 
+def is_in_expr(field: str, values: list) -> pl.Expr:
+    """Generate a polars expression equivalent to pl.col(field).is_in(values) by chaining
+    equality and `or` operations together.
+
+    This is (sometimes) faster to execute than `is_in`.
+    """
+
+    exprs = [pl.col(field) == value for value in values]
+    return reduce_or(exprs)
+
+
 def parse_yaml_expressions(yaml_content: str) -> Dict[str, pl.Expr]:
     yaml_data = yaml.safe_load(yaml_content)
     return {key: parse_expression(value).alias(key) for key, value in yaml_data.items()}
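# Sketch of the equivalence that is_in_expr provides, on a small illustrative frame:
import polars as pl

from dont_fret.expr import is_in_expr

df = pl.DataFrame({"stream": ["DD", "DA", "AA", "AD"]})

expr = is_in_expr("stream", ["DD", "DA"])
# expands to (pl.col("stream") == "DD") | (pl.col("stream") == "DA"),
# which selects the same rows as pl.col("stream").is_in(["DD", "DA"])
assert df.filter(expr)["stream"].to_list() == ["DD", "DA"]
assert df.filter(pl.col("stream").is_in(["DD", "DA"]))["stream"].to_list() == ["DD", "DA"]
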
diff --git a/templates/07_c_kde.py b/templates/07_c_kde.py
new file mode 100644
index 0000000..b3693a1
--- /dev/null
+++ b/templates/07_c_kde.py
@@ -0,0 +1,43 @@
+# %%
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import polars as pl
+
+from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
+from dont_fret.fileIO import PhotonFile
+from dont_fret.models import PhotonData
+
+# %%
+cwd = Path(__file__).parent
+test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1"
+ptu_file = "datafile_1.ptu"
+
+# %%
+photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file))
+
+tau = 50e-6
+assert photons.timestamps_unit
+kernel = make_kernel(tau, photons.timestamps_unit)
+
+fig, ax = plt.subplots()
+ax.plot(kernel)
+
+# %%
+
+# Tomov et al. eqn 9
+D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel)
+A_ex = convolve_stream(photons.data, ["AA"], kernel)
+
+# %%
+
+kde_data = photons.data.select(
+    [pl.col("timestamps"), pl.col("stream"), pl.lit(D_ex).alias("D_ex"), pl.lit(A_ex).alias("A_ex")]
+)
+
+bursts = photons.burst_search("APBS")
+
+# %%
+alex_2cde_vals = alex_2cde(bursts.photon_data, kde_data)
+
+# %%
diff --git a/tests/generate_fretbursts_kde.py b/tests/generate_fretbursts_kde.py
new file mode 100644
index 0000000..e6dc9a1
--- /dev/null
+++ b/tests/generate_fretbursts_kde.py
@@ -0,0 +1,114 @@
+"""
+Generates test data for comparing the FRETBursts / dont_fret channel KDE implementations.
+
+Ran with fretbursts '0.8.3'.
+
+"""
+# %%
+
+import numpy as np
+import polars as pl
+from fretbursts import *
+from fretbursts.phtools import phrates
+
+# %%
+
+
+url = "http://files.figshare.com/2182601/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"
+download_file(url, save_dir="./data")
+
+filename = "data/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"
+
+d = loader.photon_hdf5(filename)
+loader.alex_apply_period(d)
+d.calc_bg(fun=bg.exp_fit, time_s=20, tail_min_us="auto", F_bg=1.7)
+d.burst_search()
+
+ds1 = d.select_bursts(select_bursts.size, th1=30)
+ds = ds1.select_bursts(select_bursts.naa, th1=30)
+
+# %%
+ph = d.ph_times_m[0]
+
+
+tau_s = 50e-6  # in seconds
+tau = int(tau_s / d.clk_p)  # in raw timestamp units
+tau
+
+# %%
+
+# %%
+bursts = ds1.mburst[0]
+
+ph_dex = d.get_ph_times(ph_sel=Ph_sel(Dex="DAem"))
+ph_aex = d.get_ph_times(ph_sel=Ph_sel(Aex="Aem"))
+
+mask_dex = d.get_ph_mask(ph_sel=Ph_sel(Dex="DAem"))
+mask_aex = d.get_ph_mask(ph_sel=Ph_sel(Aex="Aem"))
+
+KDE_DexTi = phrates.kde_laplace(ph_dex, tau, time_axis=ph)
+KDE_AexTi = phrates.kde_laplace(ph_aex, tau, time_axis=ph)
+
+# %%
+
+
+# %%
+
+ALEX_2CDE = []
+BRDex, BRAex = [], []
+for ib, burst in enumerate(bursts):
+    burst_slice = slice(int(burst.istart), int(burst.istop) + 1)
+    if ~mask_dex[burst_slice].any() or ~mask_aex[burst_slice].any():
+        # Either the D or the A photon stream has no photons in the current burst,
+        # thus ALEX_2CDE cannot be computed.
+        ALEX_2CDE.append(np.nan)
+        continue
+
+    kde_dexdex = KDE_DexTi[burst_slice][mask_dex[burst_slice]]
+    kde_aexdex = KDE_AexTi[burst_slice][mask_dex[burst_slice]]
+    N_chaex = mask_aex[burst_slice].sum()
+    BRDex.append(np.sum(kde_aexdex / kde_dexdex) / N_chaex)
+
+    kde_aexaex = KDE_AexTi[burst_slice][mask_aex[burst_slice]]
+    kde_dexaex = KDE_DexTi[burst_slice][mask_aex[burst_slice]]
+    N_chdex = mask_dex[burst_slice].sum()
+    BRAex.append(np.sum(kde_dexaex / kde_aexaex) / N_chdex)
+
+    alex_2cde = 100 - 50 * (BRDex[-1] - BRAex[-1])
+    ALEX_2CDE.append(alex_2cde)
+ALEX_2CDE = np.array(ALEX_2CDE)
+
+ALEX_2CDE
+
+# %%
+
+timestamps = ph
+timestamps
+# %%
+
+stream = np.empty_like(timestamps, dtype="U2")
+stream
+
+
+streams = {
+    "DD": {"Dex": "Dem"},
+    "DA": {"Dex": "Aem"},
+    "AA": {"Aex": "Aem"},
+    "AD": {"Aex": "Dem"},
+}
+
+for stream_label, kwargs in streams.items():
+    mask = d.get_ph_mask(ph_sel=Ph_sel(**kwargs))
+    stream[mask] = stream_label
+
+# %%
+
+
+photons_export = pl.DataFrame({"timestamps": timestamps, "stream": stream})
+photons_export.write_parquet("photon_data.pq")
+
+# %%
+
+
+burst_export = pl.DataFrame([{"istart": b.istart, "istop": b.istop} for b in bursts])
+burst_export = burst_export.with_columns(pl.lit(ALEX_2CDE).alias("alex_2cde"))
+burst_export.write_parquet("burst_data.pq")
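# Sketch of the data contract between the export script above and the test that follows:
# the two parquet files hold the photon stream labels and the FRETBursts reference values
# that dont_fret's alex_2cde is compared against (column names as written by the script).
import polars as pl

photons = pl.read_parquet("photon_data.pq")  # columns: timestamps (int), stream ("DD"/"DA"/"AA"/"AD")
bursts = pl.read_parquet("burst_data.pq")    # columns: istart, istop (photon indices), alex_2cde

print(photons.head())
print(bursts.head())
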
diff --git a/tests/test_c_kde.py b/tests/test_c_kde.py
new file mode 100644
index 0000000..7300ba2
--- /dev/null
+++ b/tests/test_c_kde.py
@@ -0,0 +1,48 @@
+# %%
+
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import polars as pl
+import polars.testing as pl_test
+
+from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
+from dont_fret.fileIO import PhotonFile
+from dont_fret.models import PhotonData
+
+cwd = Path(__file__).parent
+input_data_dir = cwd / "test_data" / "input"
+output_data_dir = cwd / "test_data" / "output"
+
+
+def test_compare_cde_fretbursts():
+    tau = 50e-6
+    TIMESTAMPS_UNIT = 1.25e-08
+
+    photon_data = pl.read_parquet(output_data_dir / "kde" / "photon_data.pq")
+    bursts_ref = pl.read_parquet(output_data_dir / "kde" / "burst_data.pq")
+
+    kernel = make_kernel(tau, TIMESTAMPS_UNIT)
+    D_ex = convolve_stream(photon_data, ["DD", "DA"], kernel)
+    A_ex = convolve_stream(photon_data, ["AA"], kernel)
+    kde_data = photon_data.select(
+        [
+            pl.col("timestamps"),
+            pl.col("stream"),
+            pl.lit(D_ex).alias("D_ex"),
+            pl.lit(A_ex).alias("A_ex"),
+        ]
+    )
+
+    burst_dfs = []
+    for i in range(len(bursts_ref)):
+        istart = bursts_ref[i]["istart"].item()
+        istop = bursts_ref[i]["istop"].item()
+        b_df = photon_data[istart : istop + 1].with_columns(pl.lit(i).alias("burst_index"))
+        burst_dfs.append(b_df)
+
+    burst_data = pl.concat(burst_dfs, how="vertical_relaxed")
+
+    alex_2cde_vals = alex_2cde(burst_data, kde_data)
+    pl_test.assert_series_equal(alex_2cde_vals.fill_null(np.nan), bursts_ref["alex_2cde"])
diff --git a/tests/test_data/output/kde/burst_data.pq b/tests/test_data/output/kde/burst_data.pq
new file mode 100644
index 0000000..e7bad89
Binary files /dev/null and b/tests/test_data/output/kde/burst_data.pq differ
diff --git a/tests/test_data/output/kde/photon_data.pq b/tests/test_data/output/kde/photon_data.pq
new file mode 100644
index 0000000..139f93b
Binary files /dev/null and b/tests/test_data/output/kde/photon_data.pq differ
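# Note on the fill_null(np.nan) in the assertion above (illustrative, made-up values):
# dont_fret's alex_2cde leaves the value null for bursts lacking donor- or
# acceptor-excitation photons, while the FRETBursts reference script stores np.nan,
# so nulls are converted before comparing.
import numpy as np
import polars as pl
import polars.testing as pl_test

computed = pl.Series("alex_2cde", [65.2, None])
reference = pl.Series("alex_2cde", [65.2, np.nan])
pl_test.assert_series_equal(computed.fill_null(np.nan), reference)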