Skip to content

Commit

Permalink
add alex cde
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Nov 12, 2024
1 parent 0163d6a commit d9e3f63
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 0 deletions.
142 changes: 142 additions & 0 deletions dont_fret/channel_kde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Module for channel based kernel density estimation
This is based on:
Disentangling Subpopulations in Single-Molecule FRET and ALEX Experiments with Photon Distribution Analysis
https://doi.org/10.1016/j.bpj.2011.11.4025
Please cite the paper when using this module.
"""

import warnings
from typing import Optional

import numpy as np
import polars as pl
from numba import float64, int64, jit, types

from dont_fret.expr import is_in_expr


def alex_2cde(
    burst_photons: pl.DataFrame,
    kde_rates: pl.DataFrame,
    dex_streams: Optional[list[str]] = None,
    aex_streams: Optional[list[str]] = None,
) -> pl.Series:
    """Compute the ALEX-2CDE score per burst (Tomov et al., eqns 10-12).

    Args:
        burst_photons: Dataframe with columns: timestamps, stream, burst_index.
        kde_rates: Dataframe with columns: timestamps, stream, D_ex, A_ex. The
            ``stream`` column is needed because rows are matched to
            ``burst_photons`` on both timestamp and stream.
        dex_streams: Photon streams counted as donor excitation
            (default: DD, DA).
        aex_streams: Photon streams counted as acceptor excitation
            (default: AA).

    Returns:
        Series named ``alex_2cde`` with one entry per burst index present in
        either the donor- or the acceptor-excitation aggregation.
    """
    if not dex_streams:
        dex_streams = ["DD", "DA"]
    if not aex_streams:
        aex_streams = ["AA"]

    is_dex = is_in_expr("stream", dex_streams)
    is_aex = is_in_expr("stream", aex_streams)

    # Join on timestamps only and then require matching streams; this is
    # equivalent to (but faster than) an inner join on ['timestamps', 'stream'].
    photons_with_rates = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
        pl.col("stream") == pl.col("stream_right")
    )

    rate_ratios = photons_with_rates.select(
        [
            pl.col("burst_index"),
            pl.col("stream"),
            (pl.col("A_ex") / pl.col("D_ex")).alias("ratio_AD"),
        ]
    )

    # tomov et al eqn 10 and 11: per-burst sums over donor- and
    # acceptor-excitation photons, together with the photon counts.
    dex_per_burst = (
        rate_ratios.filter(is_dex)
        .group_by("burst_index", maintain_order=True)
        .agg([pl.col("ratio_AD").sum(), pl.len().alias("N_dex")])
    )
    aex_per_burst = (
        rate_ratios.filter(is_aex)
        .group_by("burst_index", maintain_order=True)
        .agg([(1 / pl.col("ratio_AD")).sum().alias("ratio_DA"), pl.len().alias("N_aex")])
    )

    # Align both aggregations on their shared `burst_index` column.
    per_burst = pl.concat([dex_per_burst, aex_per_burst], how="align")

    # tomov et al eqn 12
    bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") - (1 / pl.col("N_dex")) * pl.col(
        "ratio_DA"
    )
    score = pl.lit(100) - pl.lit(50) * bracket

    return per_burst.select(score.alias("alex_2cde")).to_series()


def make_kernel(
    tau: float, timestamps_unit: float, domain_size: int = 10, kernel: str = "laplace"
) -> np.ndarray:
    """Create a discretized Laplace kernel ``exp(-|t|)`` sampled on the timestamp grid.

    The kernel spans ``domain_size`` multiples of ``tau`` in total, evaluated on
    ``[-domain_size / 2, domain_size / 2]`` (in units of ``tau``). The number of
    samples is odd so that the kernel has a well-defined center element.

    Args:
        tau: Kernel time constant, in the same time unit as ``timestamps_unit``.
        timestamps_unit: Duration of one integer timestamp step.
        domain_size: Total kernel width expressed in multiples of ``tau``.
        kernel: Kernel type; only ``"laplace"`` is implemented.

    Returns:
        1D float array with ``2 * round(domain_size * tau / timestamps_unit / 2) + 1``
        elements.

    Raises:
        ValueError: If ``kernel`` is not ``"laplace"`` or if ``tau``,
            ``timestamps_unit`` or ``domain_size`` is not positive.

    Warns:
        UserWarning: If rounding the window to an even number of timestamp
            steps changes its size by more than 1 percent.
    """
    # Previously this argument was silently ignored (and then shadowed by the
    # result array); reject unsupported values instead of returning a Laplace
    # kernel under a different name.
    if kernel != "laplace":
        raise ValueError(f"Unsupported kernel {kernel!r}; only 'laplace' is implemented")
    if tau <= 0 or timestamps_unit <= 0 or domain_size <= 0:
        raise ValueError("`tau`, `timestamps_unit` and `domain_size` must be positive")

    window_size = domain_size * (tau / timestamps_unit)
    # Round to the nearest even integer so that `window_size_even_int + 1` is odd.
    window_size_even_int = 2 * round(window_size / 2)

    # Check that the rounding error isn't too large. The absolute error is at
    # most one step, so the relative error shrinks as tau / timestamps_unit grows.
    rel_dev = (window_size - window_size_even_int) / window_size
    if abs(rel_dev) > 0.01:
        warnings.warn(
            "Kernel window size deviation from rounding larger than 1 percent. "
            "Choose a larger `tau` with respect to `timestamps_unit`",
            stacklevel=2,
        )

    t_eval = np.linspace(-domain_size / 2, domain_size / 2, window_size_even_int + 1, endpoint=True)
    return np.exp(-np.abs(t_eval))


def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray) -> np.ndarray:
    """Evaluate the kernel density of the selected streams at every photon in ``data``.

    Photons whose ``stream`` value is in ``streams`` act as events; the density
    is evaluated at the timestamps of all photons in ``data``.

    Args:
        data: Dataframe with at least the columns ``timestamps`` and ``stream``.
        streams: Stream labels whose photons contribute to the density.
        kernel: Kernel array passed on to ``async_convolve``.

    Returns:
        Array with one value per row of ``data``.
    """
    # TODO warn on copy
    eval_times = data["timestamps"].to_numpy(allow_copy=True)

    selected = data.filter(is_in_expr("stream", streams))
    event_times = selected["timestamps"].to_numpy(allow_copy=True)

    return async_convolve(event_times, eval_times, kernel)


@jit(
    float64[:](
        types.Array(int64, 1, "C", readonly=True),
        types.Array(int64, 1, "C", readonly=True),
        types.Array(float64, 1, "C", readonly=True),
    ),
    nopython=True,
    cache=True,
    nogil=True,
)
def async_convolve(event_times, eval_times, kernel):
    """Convolve integer event timestamps with a kernel, evaluated at ``eval_times``.

    For every evaluation time ``t`` the kernel values of all events in the
    half-open window ``[t - h, t + h)`` are summed, where ``h = len(kernel) // 2``
    and the kernel is indexed by ``event - t + h``.

    Both ``event_times`` and ``eval_times`` are expected to be sorted in
    ascending order: the window bounds only ever move forward.

    Args:
        event_times: 1D int64 array of event timestamps.
        eval_times: 1D int64 array of timestamps at which to evaluate.
        kernel: 1D float64 kernel array.

    Returns:
        float64 array with the same length as ``eval_times``.
    """

    n_events = len(event_times)
    window_half_size = len(kernel) // 2
    result = np.zeros_like(eval_times, dtype=np.float64)

    i_lower = 0
    i_upper = 0

    for i, time in enumerate(eval_times):
        # Check the index *before* reading the array: nopython numba does not
        # bounds-check by default, so reading event_times[n_events] would be an
        # out-of-bounds memory access.
        while i_upper < n_events and event_times[i_upper] < time + window_half_size:
            i_upper += 1

        # Every event at or beyond i_upper is >= time + h and therefore can
        # never fall below the lower bound, so i_upper is a safe limit.
        while i_lower < i_upper and event_times[i_lower] < time - window_half_size:
            i_lower += 1

        for j in range(i_lower, i_upper):
            idx = event_times[j] - time + window_half_size
            result[i] += kernel[idx]

    return result
11 changes: 11 additions & 0 deletions dont_fret/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ def reduce_or(exprs: list[pl.Expr]) -> pl.Expr:
return reduce(or_, exprs)


def is_in_expr(field: str, values: list) -> pl.Expr:
    """Build a polars expression matching rows where ``field`` equals any of ``values``.

    Equivalent to ``pl.col(field).is_in(values)``, but written as a chain of
    equality comparisons combined with OR; this is (sometimes?) faster to
    execute.
    """
    column = pl.col(field)
    return reduce_or([column == value for value in values])


def parse_yaml_expressions(yaml_content: str) -> Dict[str, pl.Expr]:
    """Parse a YAML mapping into polars expressions, each aliased to its key."""
    expressions = {}
    for name, spec in yaml.safe_load(yaml_content).items():
        expressions[name] = parse_expression(spec).alias(name)
    return expressions
Expand Down
43 changes: 43 additions & 0 deletions templates/07_c_kde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# %%
from pathlib import Path

import matplotlib.pyplot as plt
import polars as pl

from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData

# %%
# Locate the example PTU file relative to this template.
cwd = Path(__file__).parent
test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1"
ptu_file = "datafile_1.ptu"

# %%
# Load the photon data and build the kernel used for the density estimate.
photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file))

# Kernel time constant; presumably in seconds, matching `timestamps_unit` - TODO confirm.
tau = 50e-6
# make_kernel divides by timestamps_unit, so it must be set (truthy) here.
assert photons.timestamps_unit
kernel = make_kernel(tau, photons.timestamps_unit)

# Visual check of the kernel shape.
fig, ax = plt.subplots()
ax.plot(kernel)

# %%

# tomov et al eqn 9
# Densities of donor-excitation (DD, DA) and acceptor-excitation (AA) photons,
# evaluated at the timestamp of every photon.
D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel)
A_ex = convolve_stream(photons.data, ["AA"], kernel)

# %%

# Collect timestamps, stream labels and both densities in the layout that
# alex_2cde expects for its `kde_rates` argument.
kde_data = photons.data.select(
    [pl.col("timestamps"), pl.col("stream"), pl.lit(D_ex).alias("D_ex"), pl.lit(A_ex).alias("A_ex")]
)

bursts = photons.burst_search("APBS")

# %%
# ALEX-2CDE value per burst.
alex_2cde_vals = alex_2cde(bursts.photon_data, kde_data)

# %%
114 changes: 114 additions & 0 deletions tests/generate_fretbursts_kde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Generates reference data for comparing fretbursts and dont_fret channel KDE results.
Generated with fretbursts version 0.8.3.
"""
# %%

import polars as pl
from fretbursts import *
from fretbursts.phtools import phrates

# %%


# Download the example measurement and run the fretbursts burst search on it.
url = "http://files.figshare.com/2182601/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"
download_file(url, save_dir="./data")

filename = "data/0023uLRpitc_NTP_20dT_0.5GndCl.hdf5"

d = loader.photon_hdf5(filename)
loader.alex_apply_period(d)
d.calc_bg(fun=bg.exp_fit, time_s=20, tail_min_us="auto", F_bg=1.7)
d.burst_search()

ds1 = d.select_bursts(select_bursts.size, th1=30)
# NOTE(review): `ds` is not used anywhere below; the export is based on `ds1`.
ds = ds1.select_bursts(select_bursts.naa, th1=30)

# %%
# Timestamps of all photons; used below as the evaluation axis and exported as-is.
ph = d.ph_times_m[0]


tau_s = 50e-6  # in seconds
tau = int(tau_s / d.clk_p)  # in raw timestamp units
tau

# %%

# %%
bursts = ds1.mburst[0]

# Photon timestamps and boolean masks for the donor-excitation (D + A emission)
# and acceptor-excitation (A emission) selections.
ph_dex = d.get_ph_times(ph_sel=Ph_sel(Dex="DAem"))
ph_aex = d.get_ph_times(ph_sel=Ph_sel(Aex="Aem"))

mask_dex = d.get_ph_mask(ph_sel=Ph_sel(Dex="DAem"))
mask_aex = d.get_ph_mask(ph_sel=Ph_sel(Aex="Aem"))

# Laplace KDE of each selection, evaluated on the time axis `ph`.
KDE_DexTi = phrates.kde_laplace(ph_dex, tau, time_axis=ph)
KDE_AexTi = phrates.kde_laplace(ph_aex, tau, time_axis=ph)

# %%


# %%

# Per-burst ALEX-2CDE; NaN marks bursts where it cannot be computed.
ALEX_2CDE = []
BRDex, BRAex = [], []
for ib, burst in enumerate(bursts):
    # presumably `istop` is an inclusive photon index, hence the +1 - TODO confirm
    burst_slice = slice(int(burst.istart), int(burst.istop) + 1)
    if ~mask_dex[burst_slice].any() or ~mask_aex[burst_slice].any():
        # Either D or A photon stream has no photons in current burst,
        # thus ALEX_2CDE cannot be computed.
        ALEX_2CDE.append(np.nan)
        continue

    # KDE values at the donor-excitation photons of this burst, normalized by
    # the number of acceptor-excitation photons in the burst.
    kde_dexdex = KDE_DexTi[burst_slice][mask_dex[burst_slice]]
    kde_aexdex = KDE_AexTi[burst_slice][mask_dex[burst_slice]]
    N_chaex = mask_aex[burst_slice].sum()
    BRDex.append(np.sum(kde_aexdex / kde_dexdex) / N_chaex)

    # KDE values at the acceptor-excitation photons of this burst, normalized
    # by the number of donor-excitation photons in the burst.
    kde_aexaex = KDE_AexTi[burst_slice][mask_aex[burst_slice]]
    kde_dexaex = KDE_DexTi[burst_slice][mask_aex[burst_slice]]
    N_chdex = mask_dex[burst_slice].sum()
    BRAex.append(np.sum(kde_dexaex / kde_aexaex) / N_chdex)

    alex_2cde = 100 - 50 * (BRDex[-1] - BRAex[-1])
    ALEX_2CDE.append(alex_2cde)
ALEX_2CDE = np.array(ALEX_2CDE)

ALEX_2CDE

# %%

timestamps = ph
timestamps
# %%

# Two-character stream label per photon, filled in from the masks below.
stream = np.empty_like(timestamps, dtype="U2")
stream


# Stream label -> keyword arguments of the matching Ph_sel photon selection.
streams = {
    "DD": {"Dex": "Dem"},
    "DA": {"Dex": "Aem"},
    "AA": {"Aex": "Aem"},
    "AD": {"Aex": "Dem"},
}

for stream_label, kwargs in streams.items():
    mask = d.get_ph_mask(ph_sel=Ph_sel(**kwargs))
    stream[mask] = stream_label

# %%


# Photon table read by tests/test_c_kde.py.
photons_export = pl.DataFrame({"timestamps": timestamps, "stream": stream})
photons_export.write_parquet("photon_data.pq")

# %%


# Burst start/stop photon indices together with the reference ALEX-2CDE values.
burst_export = pl.DataFrame([{"istart": b.istart, "istop": b.istop} for b in bursts])
burst_export = burst_export.with_columns(pl.lit(ALEX_2CDE).alias("alex_2cde"))
burst_export.write_parquet("burst_data.pq")
48 changes: 48 additions & 0 deletions tests/test_c_kde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# %%

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import polars.testing as pl_test

from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData

cwd = Path(__file__).parent
input_data_dir = cwd / "test_data" / "input"
output_data_dir = cwd / "test_data" / "output"


def test_compare_cde_fretbursts():
    """Check that alex_2cde reproduces the reference values exported from fretbursts."""
    tau = 50e-6
    TIMESTAMPS_UNIT = 1.25e-08

    kde_dir = output_data_dir / "kde"
    photon_data = pl.read_parquet(kde_dir / "photon_data.pq")
    bursts_ref = pl.read_parquet(kde_dir / "burst_data.pq")

    # Densities of donor- and acceptor-excitation photons at every photon.
    kernel = make_kernel(tau, TIMESTAMPS_UNIT)
    D_ex = convolve_stream(photon_data, ["DD", "DA"], kernel)
    A_ex = convolve_stream(photon_data, ["AA"], kernel)
    kde_columns = [
        pl.col("timestamps"),
        pl.col("stream"),
        pl.lit(D_ex).alias("D_ex"),
        pl.lit(A_ex).alias("A_ex"),
    ]
    kde_data = photon_data.select(kde_columns)

    # Rebuild the per-burst photon table from the reference start/stop indices
    # (the stop index is included in the slice).
    burst_dfs = [
        photon_data[row["istart"] : row["istop"] + 1].with_columns(pl.lit(i).alias("burst_index"))
        for i, row in enumerate(bursts_ref.iter_rows(named=True))
    ]
    burst_data = pl.concat(burst_dfs, how="vertical_relaxed")

    alex_2cde_vals = alex_2cde(burst_data, kde_data)
    # Nulls are mapped to NaN, which is how the reference marks undefined values.
    pl_test.assert_series_equal(alex_2cde_vals.fill_null(np.nan), bursts_ref["alex_2cde"])
Binary file added tests/test_data/output/kde/burst_data.pq
Binary file not shown.
Binary file added tests/test_data/output/kde/photon_data.pq
Binary file not shown.

0 comments on commit d9e3f63

Please sign in to comment.