-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
358 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
""" | ||
Module for channel based kernel density estimation | ||
This is based on: | ||
Disentangling Subpopulations in Single-Molecule FRET and ALEX Experiments with Photon Distribution Analysis | ||
https://doi.org/10.1016/j.bpj.2011.11.4025 | ||
Please cite the paper when using this module. | ||
""" | ||
|
||
import warnings | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import polars as pl | ||
from numba import float64, int64, jit, types | ||
|
||
from dont_fret.expr import is_in_expr | ||
|
||
|
||
def alex_2cde(
    burst_photons: pl.DataFrame,
    kde_rates: pl.DataFrame,
    dex_streams: Optional[list[str]] = None,
    aex_streams: Optional[list[str]] = None,
) -> pl.Series:
    """
    Compute the ALEX-2CDE score per burst (Tomov et al., eqns 10-12).

    burst_photons: Dataframe with columns: timestamps, stream, burst_index
    kde_rates: Dataframe with columns: timestamps, stream, D_ex, A_ex.
        The `stream` column is required: after the join it is read back as
        `stream_right` to match photons on both timestamp and stream.
    dex_streams: list of photon streams which are donor excitation (default: DD, DA)
    aex_streams: list of photon streams which are acceptor excitation (default: AA)

    Returns a series named `alex_2cde` with one entry per row of the aligned
    per-burst donor/acceptor aggregates.
    """
    # A missing (or empty) stream list falls back to the defaults.
    if not dex_streams:
        dex_streams = ["DD", "DA"]
    if not aex_streams:
        aex_streams = ["AA"]

    is_dex = is_in_expr("stream", dex_streams)
    is_aex = is_in_expr("stream", aex_streams)

    # Left join on timestamps followed by a stream-equality filter; equivalent to
    # (but faster than) an inner join on ['timestamps', 'stream'].
    matched = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
        pl.col("stream") == pl.col("stream_right")
    )

    ratios = matched.select(
        [
            pl.col("burst_index"),
            pl.col("stream"),
            (pl.col("A_ex") / pl.col("D_ex")).alias("ratio_AD"),
        ]
    )

    # tomov et al eqn 10 and 11: per-burst sums of the rate ratios, evaluated at the
    # donor-excitation photons (A/D) and at the acceptor-excitation photons (D/A).
    dex_per_burst = (
        ratios.filter(is_dex)
        .group_by("burst_index", maintain_order=True)
        .agg([pl.col("ratio_AD").sum(), pl.len().alias("N_dex")])
    )
    aex_per_burst = (
        ratios.filter(is_aex)
        .group_by("burst_index", maintain_order=True)
        .agg([(1 / pl.col("ratio_AD")).sum().alias("ratio_DA"), pl.len().alias("N_aex")])
    )

    per_burst = pl.concat([dex_per_burst, aex_per_burst], how="align")

    # tomov et al eqn 12
    dex_term = (1 / pl.col("N_aex")) * pl.col("ratio_AD")
    aex_term = (1 / pl.col("N_dex")) * pl.col("ratio_DA")
    score = pl.lit(100) - pl.lit(50) * (dex_term - aex_term)

    return per_burst.select(score.alias("alex_2cde")).to_series()
|
||
|
||
def make_kernel(
    tau: float, timestamps_unit: float, domain_size: int = 10, kernel: str = "laplace"
) -> np.ndarray:
    """
    Build a sampled Laplace kernel ``exp(-|t|)`` for channel KDE.

    tau: kernel time constant, in the same time unit as `timestamps_unit`.
    timestamps_unit: duration of a single timestamp tick.
    domain_size: total width of the kernel support, in units of `tau`.
    kernel: kernel shape; only "laplace" is implemented.

    Returns an array of odd length (an even number of ticks plus the centre
    sample) evaluated on ``[-domain_size / 2, domain_size / 2]``.

    Raises ValueError for an unsupported `kernel` or non-positive sizes.
    Emits a warning when rounding the window to an even number of ticks changes
    its size by more than 1 percent.
    """
    # Previously this argument was silently ignored (and shadowed by the result).
    if kernel != "laplace":
        raise ValueError(f"Unsupported kernel {kernel!r}; only 'laplace' is implemented")
    if tau <= 0 or timestamps_unit <= 0 or domain_size <= 0:
        raise ValueError("`tau`, `timestamps_unit` and `domain_size` must all be positive")

    # Window width expressed in timestamp ticks, rounded to the nearest even
    # integer so that the kernel has a well-defined centre sample.
    window_size = domain_size * (tau / timestamps_unit)
    window_size_even_int = 2 * round(window_size / 2)

    # check that rounding error isn't too large; the relative error grows as the
    # window (i.e. tau / timestamps_unit) gets smaller.
    rel_dev = (window_size - window_size_even_int) / window_size
    if np.abs(rel_dev) > 0.01:
        warnings.warn(
            "Kernel window size deviation from rounding larger than 1 percent. "
            "Choose a larger `tau` with respect to `timestamps_unit`",
            stacklevel=2,
        )

    t_eval = np.linspace(-domain_size / 2, domain_size / 2, window_size_even_int + 1, endpoint=True)
    return np.exp(-np.abs(t_eval))
|
||
|
||
def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray) -> np.ndarray:
    """
    Convolve the photons of the selected streams with `kernel`.

    The photons whose `stream` matches `streams` act as events; the result is
    evaluated at the timestamps of every row in `data`, so the returned array
    has one value per row of `data`.
    """
    # TODO warn on copy
    eval_times = data["timestamps"].to_numpy(allow_copy=True)
    selected = data.filter(is_in_expr("stream", streams))
    event_times = selected["timestamps"].to_numpy(allow_copy=True)

    return async_convolve(event_times, eval_times, kernel)
|
||
|
||
@jit(
    float64[:](
        types.Array(int64, 1, "C", readonly=True),
        types.Array(int64, 1, "C", readonly=True),
        types.Array(float64, 1, "C", readonly=True),
    ),
    nopython=True,
    cache=True,
    nogil=True,
)
def async_convolve(event_times, eval_times, kernel):
    """
    Convolve integer timestamps with a kernel.

    event_times: integer timestamps of the events contributing to the density.
    eval_times: integer timestamps at which the density is evaluated.
    kernel: sampled kernel; the sample at index ``len(kernel) // 2`` is applied
        to an event coinciding with the evaluation time, and neighbouring
        samples to events one timestamp tick further away each.

    Returns one float per entry of `eval_times`: the sum of kernel values over
    all events within the half-open window
    ``[time - half_window, time + half_window)``.

    The sliding-window indices only ever move forward, so both inputs are
    assumed to be sorted in ascending order.
    """
    n_events = len(event_times)
    window_half_size = len(kernel) // 2
    result = np.zeros_like(eval_times, dtype=np.float64)

    i_lower = 0
    i_upper = 0

    for i, time in enumerate(eval_times):
        # Advance the exclusive upper index. The bound is tested *before* the
        # element is read: nopython mode does not bounds-check by default, so
        # reading event_times[n_events] would silently access invalid memory.
        while i_upper < n_events and event_times[i_upper] < time + window_half_size:
            i_upper += 1

        # Advance the inclusive lower index; it can never need to pass i_upper,
        # which also keeps it in bounds when every event lies before the window.
        while i_lower < i_upper and event_times[i_lower] < time - window_half_size:
            i_lower += 1

        for j in range(i_lower, i_upper):
            # Offset of the event from the evaluation time, shifted so that the
            # kernel centre corresponds to zero offset.
            idx = event_times[j] - time + window_half_size
            result[i] += kernel[idx]

    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# %% | ||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import polars as pl | ||
|
||
from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel | ||
from dont_fret.fileIO import PhotonFile | ||
from dont_fret.models import PhotonData | ||
|
||
# %%
# Locate the example PTU file inside the repository's test data.
cwd = Path(__file__).parent
test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1"
ptu_file = "datafile_1.ptu"

# %%
# Load the photons and build the kernel used for the rate estimates.
photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file))

tau = 50e-6
assert photons.timestamps_unit
kernel = make_kernel(tau, photons.timestamps_unit)

fig, ax = plt.subplots()
ax.plot(kernel)

# %%

# tomov et al eqn 9
D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel)
A_ex = convolve_stream(photons.data, ["AA"], kernel)

# %%

# Attach both rate estimates to the timestamp / stream of every photon.
kde_columns = [
    pl.col("timestamps"),
    pl.col("stream"),
    pl.lit(D_ex).alias("D_ex"),
    pl.lit(A_ex).alias("A_ex"),
]
kde_data = photons.data.select(kde_columns)

bursts = photons.burst_search("APBS")

# %%
alex_2cde_vals = alex_2cde(bursts.photon_data, kde_data)
# %% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
""" | ||
generates test data for comparing fretbursts / dontfret channel kde | ||
ran with fretbursts '0.8.3' | ||
""" | ||
# %% | ||
|
||
import polars as pl | ||
from fretbursts import * | ||
from fretbursts.phtools import phrates | ||
|
||
# %% | ||
|
||
|
||
# Download the example measurement and run the fretbursts burst search on it.
url = "http://files.figshare.com/2182601/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"
download_file(url, save_dir="./data")

filename = "data/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"

d = loader.photon_hdf5(filename)
loader.alex_apply_period(d)
d.calc_bg(fun=bg.exp_fit, time_s=20, tail_min_us="auto", F_bg=1.7)
d.burst_search()

ds1 = d.select_bursts(select_bursts.size, th1=30)
ds = ds1.select_bursts(select_bursts.naa, th1=30)

# %%
ph = d.ph_times_m[0]

tau_s = 50e-6  # in seconds
tau = int(tau_s / d.clk_p)  # in raw timestamp units
tau

# %%

# %%
bursts = ds1.mburst[0]

# Photon timestamps and boolean masks for the donor- and acceptor-excitation selections.
ph_dex = d.get_ph_times(ph_sel=Ph_sel(Dex="DAem"))
ph_aex = d.get_ph_times(ph_sel=Ph_sel(Aex="Aem"))

mask_dex = d.get_ph_mask(ph_sel=Ph_sel(Dex="DAem"))
mask_aex = d.get_ph_mask(ph_sel=Ph_sel(Aex="Aem"))

# Laplace KDE of each selection, evaluated at every photon timestamp.
KDE_DexTi = phrates.kde_laplace(ph_dex, tau, time_axis=ph)
KDE_AexTi = phrates.kde_laplace(ph_aex, tau, time_axis=ph)

# %%

# %%

ALEX_2CDE = []
BRDex, BRAex = [], []
for ib, burst in enumerate(bursts):
    burst_slice = slice(int(burst.istart), int(burst.istop) + 1)
    dex_in_burst = mask_dex[burst_slice]
    aex_in_burst = mask_aex[burst_slice]
    if not (dex_in_burst.any() and aex_in_burst.any()):
        # Either D or A photon stream has no photons in current burst,
        # thus ALEX_2CDE cannot be computed.
        ALEX_2CDE.append(np.nan)
        continue

    kde_dex_burst = KDE_DexTi[burst_slice]
    kde_aex_burst = KDE_AexTi[burst_slice]

    # Rate ratio A/D at the donor-excitation photons, normalised by the
    # number of acceptor-excitation photons.
    kde_dexdex = kde_dex_burst[dex_in_burst]
    kde_aexdex = kde_aex_burst[dex_in_burst]
    N_chaex = aex_in_burst.sum()
    BRDex.append(np.sum(kde_aexdex / kde_dexdex) / N_chaex)

    # Rate ratio D/A at the acceptor-excitation photons, normalised by the
    # number of donor-excitation photons.
    kde_aexaex = kde_aex_burst[aex_in_burst]
    kde_dexaex = kde_dex_burst[aex_in_burst]
    N_chdex = dex_in_burst.sum()
    BRAex.append(np.sum(kde_dexaex / kde_aexaex) / N_chdex)

    alex_2cde = 100 - 50 * (BRDex[-1] - BRAex[-1])
    ALEX_2CDE.append(alex_2cde)
ALEX_2CDE = np.array(ALEX_2CDE)

ALEX_2CDE

# %%

timestamps = ph
timestamps
# %%

stream = np.empty_like(timestamps, dtype="U2")
stream

# Map the dont_fret stream labels onto the matching fretbursts photon selections.
streams = {
    "DD": {"Dex": "Dem"},
    "DA": {"Dex": "Aem"},
    "AA": {"Aex": "Aem"},
    "AD": {"Aex": "Dem"},
}

for stream_label, kwargs in streams.items():
    mask = d.get_ph_mask(ph_sel=Ph_sel(**kwargs))
    stream[mask] = stream_label

# %%

# Export the labelled photons and the per-burst reference values.
photons_export = pl.DataFrame({"timestamps": timestamps, "stream": stream})
photons_export.write_parquet("photon_data.pq")

# %%

burst_export = pl.DataFrame([{"istart": b.istart, "istop": b.istop} for b in bursts])
burst_export = burst_export.with_columns(pl.lit(ALEX_2CDE).alias("alex_2cde"))
burst_export.write_parquet("burst_data.pq")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# %% | ||
|
||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import polars as pl | ||
import polars.testing as pl_test | ||
|
||
from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel | ||
from dont_fret.fileIO import PhotonFile | ||
from dont_fret.models import PhotonData | ||
|
||
cwd = Path(__file__).parent
input_data_dir = cwd / "test_data" / "input"
output_data_dir = cwd / "test_data" / "output"


def test_compare_cde_fretbursts():
    """Check ALEX-2CDE values against the stored reference values in `burst_data.pq`."""
    tau = 50e-6
    TIMESTAMPS_UNIT = 1.25e-08

    kde_dir = output_data_dir / "kde"
    photon_data = pl.read_parquet(kde_dir / "photon_data.pq")
    bursts_ref = pl.read_parquet(kde_dir / "burst_data.pq")

    # Rate estimates for both excitation selections at every photon timestamp.
    kernel = make_kernel(tau, TIMESTAMPS_UNIT)
    D_ex = convolve_stream(photon_data, ["DD", "DA"], kernel)
    A_ex = convolve_stream(photon_data, ["AA"], kernel)
    kde_data = photon_data.select(
        [
            pl.col("timestamps"),
            pl.col("stream"),
            pl.lit(D_ex).alias("D_ex"),
            pl.lit(A_ex).alias("A_ex"),
        ]
    )

    # Cut each reference burst (inclusive istart..istop) out of the photon
    # table and tag its photons with the burst's position in the reference.
    burst_bounds = zip(bursts_ref["istart"].to_list(), bursts_ref["istop"].to_list())
    burst_dfs = [
        photon_data[istart : istop + 1].with_columns(pl.lit(i).alias("burst_index"))
        for i, (istart, istop) in enumerate(burst_bounds)
    ]
    burst_data = pl.concat(burst_dfs, how="vertical_relaxed")

    alex_2cde_vals = alex_2cde(burst_data, kde_data)
    # The reference stores NaN where the score is undefined, so nulls are mapped to NaN.
    pl_test.assert_series_equal(alex_2cde_vals.fill_null(np.nan), bursts_ref["alex_2cde"])
Binary file not shown.
Binary file not shown.