From 7f2153762adb50945ba7238335c34d8badd768cf Mon Sep 17 00:00:00 2001 From: Jochem Smit Date: Tue, 12 Nov 2024 22:23:11 +0100 Subject: [PATCH] fret cde --- dont_fret/channel_kde.py | 80 +++++++++++++++++++++++++++++++++++----- dont_fret/models.py | 12 +++++- templates/07_c_kde.py | 28 +++++++++++--- tests/test_c_kde.py | 4 +- 4 files changed, 105 insertions(+), 19 deletions(-) diff --git a/dont_fret/channel_kde.py b/dont_fret/channel_kde.py index fac0a3a..17da5a8 100644 --- a/dont_fret/channel_kde.py +++ b/dont_fret/channel_kde.py @@ -19,7 +19,57 @@ from dont_fret.expr import is_in_expr -def alex_2cde( +def compute_fret_2cde( + indices: pl.DataFrame, + kde_rates, +) -> np.ndarray: + output = np.empty(len(indices), dtype=float) + for idx, (imin, imax) in enumerate(indices.iter_rows()): + # df = kde_rates[imin:imax+1] # this is slower + df = kde_rates.slice(imin, imax - imin + 1) + + # filtering like this is also slower + # df_f_DA = df.filter(f_DA) # select only the DA photons, thus these are for KDE^X_DA + # df_f_DD = df.filter(f_DD) # KDE^X_DD (density of ch X at DD timestamps) + + # kde_DA_DA = df_f_DA['DA'] # select DA density - kde^DA_DA + # kde_DA_DD = df_f_DD['DA'] # kde^DA_DD (in the paper called kde^A_D) + # kde_DD_DD = df_f_DD['DD'] # kde^DD_DD + # kde_DD_DA = df_f_DA['DD'] # kde^DD_DA + + bools_DA = df["stream"] == "DA" + bools_DD = df["stream"] == "DD" + + kde_DA_DA = df["DA"].filter(bools_DA) # select DA density - kde^DA_DA + kde_DA_DD = df["DA"].filter(bools_DD) # kde^DA_DD (in the paper called kde^A_D) + kde_DD_DD = df["DD"].filter(bools_DD) # kde^DD_DD + kde_DD_DA = df["DD"].filter(bools_DA) # kde^DD_DA + + try: + nbkde_DA_DA = (1 + 2 / len(kde_DA_DA)) * (kde_DA_DA - 1) + nbkde_DD_DD = (1 + 2 / len(kde_DD_DD)) * (kde_DD_DD - 1) + + # ED = (kde_DA_DD / (kde_DA_DD + nbkde_DD_DD)).sum() / nbkde_DD_DD.count() + # EA = (kde_DD_DA / (kde_DD_DA + nbkde_DA_DA)).sum() / nbkde_DA_DA.count() # = (1 - E)_A + + # when denom is zero, it doesnt count towards number of photons + # see "Such cases are removed by the computer algorithm", in Tomov et al. + denom = kde_DA_DD + nbkde_DD_DD + ED = (kde_DA_DD / denom).sum() / (denom != 0.0).sum() # when + + denom = kde_DD_DA + nbkde_DA_DA + EA = (kde_DD_DA / denom).sum() / (denom != 0.0).sum() # = (1 - E)_A + + fret_cde = 110 - 100 * (ED + EA) + output[idx] = fret_cde + except ZeroDivisionError: + output[idx] = np.nan + + return output + + +# refactor to indices / loop +def compute_alex_2cde( burst_photons: pl.DataFrame, kde_rates: pl.DataFrame, dex_streams: Optional[list[str]] = None, @@ -39,11 +89,14 @@ def alex_2cde( # equivalent to (but faster): # joined_df = burst_photons.join(kde_rates, on=['timestamps', 'stream'], how='inner') - j1 = burst_photons.join(kde_rates, on=["timestamps"], how="left") - joined_df = j1.filter(pl.col("stream") == pl.col("stream_right")) - # joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter( - # pl.col("stream") == pl.col("stream_right") - # ) + # still, this is the slowest step. would be nice if we can improve since we know the indices + + # remove comments after passing test: + # j1 = burst_photons.join(kde_rates, on=["timestamps"], how="left") + # joined_df = j1.filter(pl.col("stream") == pl.col("stream_right")) + joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter( + pl.col("stream") == pl.col("stream_right") + ) b_df = joined_df.select( [ @@ -67,9 +120,12 @@ def alex_2cde( combined = pl.concat([agg_dex, agg_dax], how="align") # tomov et al eqn 12 + # this is an addition in the paper + # in fretbursts its subtracted ax_2cde_bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") - (1 / pl.col("N_dex")) * pl.col( "ratio_DA" ) + ax_2cde_norm = pl.lit(100) - pl.lit(50) * ax_2cde_bracket alex_2cde = combined.select(ax_2cde_norm.alias("alex_2cde")).to_series() @@ -104,6 +160,9 @@ def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray) event_times = df["timestamps"].to_numpy(allow_copy=True) eval_times = data["timestamps"].to_numpy(allow_copy=True) + # event_times = np.array(df["timestamps"]) + # eval_times = np.array(data["timestamps"]) + return async_convolve(event_times, eval_times, kernel) @@ -113,8 +172,7 @@ def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray) types.Array(int64, 1, "C", readonly=True), types.Array(float64, 1, "C", readonly=True), ), - nopython=True, - cache=True, + nopython=False, nogil=True, ) def async_convolve(event_times, eval_times, kernel): @@ -128,12 +186,16 @@ def async_convolve(event_times, eval_times, kernel): for i, time in enumerate(eval_times): while event_times[i_upper] < time + window_half_size: + i_upper += 1 if i_upper == len(event_times): + i_upper -= 1 break - i_upper += 1 while event_times[i_lower] < time - window_half_size: i_lower += 1 + if i_lower == len(event_times): + i_lower -= 1 + break for j in range(i_lower, i_upper): idx = event_times[j] - time + window_half_size diff --git a/dont_fret/models.py b/dont_fret/models.py index 418d1d4..b9f13cb 100644 --- a/dont_fret/models.py +++ b/dont_fret/models.py @@ -250,18 +250,21 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: # Check if any of the times _items is empty, if so, bursts is empty if any(len(t) == 0 for t in times_list): burst_photons = pl.DataFrame({k: [] for k in self.data.columns + ["burst_index"]}) + indices = pl.DataFrame({"imin": [], "imax": []}) else: # Take the intersection of the time intervals found by the multi-color burst search final_times = reduce(return_intersections, times_list) if len(final_times) == 0: # No overlap found burst_photons = pl.DataFrame({k: [] for k in self.data.columns + ["burst_index"]}) + indices = pl.DataFrame({"imin": [], "imax": []}) else: tmin, tmax = np.array(final_times).T # Convert back to indices imin = np.searchsorted(self.timestamps, tmin) imax = np.searchsorted(self.timestamps, tmax) + indices = pl.DataFrame({"imin": imin, "imax": imax}) # take all photons (up to and including? edges need to be checked!) b_num = int(2 ** np.ceil(np.log2((np.log2(len(imin)))))) dtype = getattr(pl, f"UInt{b_num}", pl.Int32) @@ -271,7 +274,7 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: ] burst_photons = pl.concat(bursts) - bs = Bursts(burst_photons, metadata=self.metadata) + bs = Bursts(burst_photons, indices=indices, metadata=self.metadata) return bs @@ -364,9 +367,14 @@ class Bursts(object): # todo add metadata support # bursts: numpy.typing.ArrayLike[Bursts] ? def __init__( - self, photon_data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = cfg + self, + photon_data: pl.DataFrame, + indices: pl.DataFrame, + metadata: Optional[dict] = None, + cfg: DontFRETConfig = cfg, ): self.photon_data = photon_data + self.indices = indices self.metadata: dict = metadata or {} self.cfg = cfg diff --git a/templates/07_c_kde.py b/templates/07_c_kde.py index b3693a1..c3559ea 100644 --- a/templates/07_c_kde.py +++ b/templates/07_c_kde.py @@ -4,17 +4,20 @@ import matplotlib.pyplot as plt import polars as pl -from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel +from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel +from dont_fret.expr import is_in_expr from dont_fret.fileIO import PhotonFile from dont_fret.models import PhotonData # %% cwd = Path(__file__).parent -test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1" +test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds3" ptu_file = "datafile_1.ptu" +ptu_file = "210122_sFRET_MBP5K_apo_15.ptu" # %% photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file)) +bursts = photons.burst_search("APBS") tau = 50e-6 assert photons.timestamps_unit @@ -24,20 +27,33 @@ ax.plot(kernel) # %% - # tomov et al eqn 9 D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel) -A_ex = convolve_stream(photons.data, ["AA"], kernel) +A_ex = convolve_stream(photons.data, ["AA"], kernel) # crashed the kernel (sometimes) # %% kde_data = photons.data.select( [pl.col("timestamps"), pl.col("stream"), pl.lit(D_ex).alias("D_ex"), pl.lit(A_ex).alias("A_ex")] ) +alex_2cde = compute_alex_2cde(bursts.photon_data, kde_data) -bursts = photons.burst_search("APBS") +DA = convolve_stream(photons.data, ["DA"], kernel) +DD = convolve_stream(photons.data, ["DD"], kernel) +kde_data = photons.data.select( + [pl.col("timestamps"), pl.col("stream"), pl.lit(DA).alias("DA"), pl.lit(DD).alias("DD")] +) + +# %% + +fret_2cde = compute_fret_2cde(bursts.indices, kde_data) + +fret_2cde # %% -alex_2cde_vals = alex_2cde(bursts.photon_data, kde_data) +fig, axes = plt.subplots(ncols=2) +axes[0].hist(alex_2cde, bins="fd") +axes[1].hist(fret_2cde, bins="fd") +# axes[1].axvline(10, color='r') # %% diff --git a/tests/test_c_kde.py b/tests/test_c_kde.py index 7300ba2..3445c5b 100644 --- a/tests/test_c_kde.py +++ b/tests/test_c_kde.py @@ -7,7 +7,7 @@ import polars as pl import polars.testing as pl_test -from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel +from dont_fret.channel_kde import compute_alex_2cde, convolve_stream, make_kernel from dont_fret.fileIO import PhotonFile from dont_fret.models import PhotonData @@ -44,5 +44,5 @@ def test_compare_cde_fretbursts(): burst_data = pl.concat(burst_dfs, how="vertical_relaxed") - alex_2cde_vals = alex_2cde(burst_data, kde_data) + alex_2cde_vals = compute_alex_2cde(burst_data, kde_data) pl_test.assert_series_equal(alex_2cde_vals.fill_null(np.nan), bursts_ref["alex_2cde"])