Skip to content

Commit

Permalink
faster fretcde
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Nov 22, 2024
1 parent 7f21537 commit 058e1dd
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 19 deletions.
81 changes: 68 additions & 13 deletions dont_fret/channel_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,84 @@


def compute_fret_2cde(
burst_photons: pl.DataFrame,
kde_rates: pl.DataFrame,
# channels
) -> np.ndarray:
joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
pl.col("stream") == pl.col("stream_right")
)

f_DA = pl.col("stream") == "DA"
f_DD = pl.col("stream") == "DD"
df_DA = joined_df.filter(f_DA)
df_DD = joined_df.filter(f_DD)

matching_indices = np.intersect1d(df_DA["burst_index"].unique(), df_DD["burst_index"].unique())

DA_groups = {k: v for k, v in df_DA.group_by("burst_index")}
DD_groups = {k: v for k, v in df_DD.group_by("burst_index")}

N: int = burst_photons["burst_index"].max() + 1 # type: ignore
output = np.full(N, fill_value=np.nan)
for j in matching_indices:
df_DA_j = DA_groups[(j,)]
df_DD_j = DD_groups[(j,)]

kde_DA_DA = df_DA_j["DA"].to_numpy() # select DA density - kde^DA_DA
kde_DA_DD = df_DD_j["DA"].to_numpy() # kde^DA_DD (in the paper called kde^A_D)
kde_DD_DD = df_DD_j["DD"].to_numpy() # kde^DD_DD
kde_DD_DA = df_DA_j["DD"].to_numpy() # kde^DD_DA

try:
nbkde_DA_DA = (1 + 2 / len(kde_DA_DA)) * (kde_DA_DA - 1)
nbkde_DD_DD = (1 + 2 / len(kde_DD_DD)) * (kde_DD_DD - 1)

# when denom is zero, it doesnt count towards number of photons
# see "Such cases are removed by the computer algorithm", in Tomov et al.
denom = kde_DA_DD + nbkde_DD_DD
ED = (kde_DA_DD / denom).sum() / np.count_nonzero(denom) # when

denom = kde_DD_DA + nbkde_DA_DA
EA = (kde_DD_DA / denom).sum() / np.count_nonzero(denom) # = (1 - E)_A

fret_cde = 110 - 100 * (ED + EA)
output[j] = fret_cde
except ZeroDivisionError:
output[j] = np.nan
return output


def compute_fret_2cde_v1(
indices: pl.DataFrame,
kde_rates,
# TODO channel names
) -> np.ndarray:
output = np.empty(len(indices), dtype=float)

f_DA = pl.col("stream") == "DA"
f_DD = pl.col("stream") == "DD"

for idx, (imin, imax) in enumerate(indices.iter_rows()):
# df = kde_rates[imin:imax+1] # this is slower
df = kde_rates.slice(imin, imax - imin + 1)

# filtering like this is also slower
# df_f_DA = df.filter(f_DA) # select only the DA photons, thus these are for KDE^X_DA
# df_f_DD = df.filter(f_DD) # KDE^X_DD (density of ch X at DD timestamps)
df_f_DA = df.filter(f_DA) # select only the DA photons, thus these are for KDE^X_DA
df_f_DD = df.filter(f_DD) # KDE^X_DD (density of ch X at DD timestamps)

# kde_DA_DA = df_f_DA['DA'] # select DA density - kde^DA_DA
# kde_DA_DD = df_f_DD['DA'] # kde^DA_DD (in the paper called kde^A_D)
# kde_DD_DD = df_f_DD['DD'] # kde^DD_DD
# kde_DD_DA = df_f_DA['DD'] # kde^DD_DA
kde_DA_DA = df_f_DA["DA"] # select DA density - kde^DA_DA
kde_DA_DD = df_f_DD["DA"] # kde^DA_DD (in the paper called kde^A_D)
kde_DD_DD = df_f_DD["DD"] # kde^DD_DD
kde_DD_DA = df_f_DA["DD"] # kde^DD_DA

bools_DA = df["stream"] == "DA"
bools_DD = df["stream"] == "DD"
# bools_DA = df["stream"] == "DA"
# bools_DD = df["stream"] == "DD"

kde_DA_DA = df["DA"].filter(bools_DA) # select DA density - kde^DA_DA
kde_DA_DD = df["DA"].filter(bools_DD) # kde^DA_DD (in the paper called kde^A_D)
kde_DD_DD = df["DD"].filter(bools_DD) # kde^DD_DD
kde_DD_DA = df["DD"].filter(bools_DA) # kde^DD_DA
# kde_DA_DA = df["DA"].filter(bools_DA) # select DA density - kde^DA_DA
# kde_DA_DD = df["DA"].filter(bools_DD) # kde^DA_DD (in the paper called kde^A_D)
# kde_DD_DD = df["DD"].filter(bools_DD) # kde^DD_DD
# kde_DD_DA = df["DD"].filter(bools_DA) # kde^DD_DA

try:
nbkde_DA_DA = (1 + 2 / len(kde_DA_DA)) * (kde_DA_DA - 1)
Expand Down Expand Up @@ -122,7 +176,7 @@ def compute_alex_2cde(
# tomov et al eqn 12
# this is an addition in the paper
# in fretbursts its subtracted
ax_2cde_bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") - (1 / pl.col("N_dex")) * pl.col(
ax_2cde_bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") + (1 / pl.col("N_dex")) * pl.col(
"ratio_DA"
)

Expand Down Expand Up @@ -157,6 +211,7 @@ def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray)

df = data.filter(f_expr)
# TODO warn on copy
# dataframes read from .pq cannot be converted zero-copy
event_times = df["timestamps"].to_numpy(allow_copy=True)
eval_times = data["timestamps"].to_numpy(allow_copy=True)

Expand Down
19 changes: 13 additions & 6 deletions templates/07_c_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import polars as pl

from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel
from dont_fret.expr import is_in_expr
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData

Expand Down Expand Up @@ -36,24 +35,32 @@
kde_data = photons.data.select(
[pl.col("timestamps"), pl.col("stream"), pl.lit(D_ex).alias("D_ex"), pl.lit(A_ex).alias("A_ex")]
)

# %%
alex_2cde = compute_alex_2cde(bursts.photon_data, kde_data)

# %%
DA = convolve_stream(photons.data, ["DA"], kernel)
DD = convolve_stream(photons.data, ["DD"], kernel)
kde_data = photons.data.select(
[pl.col("timestamps"), pl.col("stream"), pl.lit(DA).alias("DA"), pl.lit(DD).alias("DD")]
)
kde_rates = kde_data

# %%
# 500 ms
fret_2cde = compute_fret_2cde(bursts.photon_data, kde_data)

fret_2cde = compute_fret_2cde(bursts.indices, kde_data)

fret_2cde
# %%
burst_photons = bursts.photon_data
joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
pl.col("stream") == pl.col("stream_right")
)

# %%

fig, axes = plt.subplots(ncols=2)
axes[0].hist(alex_2cde, bins="fd")
axes[1].hist(fret_2cde, bins="fd")
h = axes[0].hist(alex_2cde, bins="fd")
h = axes[1].hist(fret_2cde, bins="fd")
# axes[1].axvline(10, color='r')
# %%

0 comments on commit 058e1dd

Please sign in to comment.