Skip to content

Commit

Permalink
fret cde
Browse files Browse the repository at this point in the history
  • Loading branch information
Jhsmit committed Nov 12, 2024
1 parent d9e3f63 commit 7f21537
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 19 deletions.
80 changes: 71 additions & 9 deletions dont_fret/channel_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,57 @@
from dont_fret.expr import is_in_expr


def alex_2cde(
def compute_fret_2cde(
indices: pl.DataFrame,
kde_rates,
) -> np.ndarray:
output = np.empty(len(indices), dtype=float)
for idx, (imin, imax) in enumerate(indices.iter_rows()):
# df = kde_rates[imin:imax+1] # this is slower
df = kde_rates.slice(imin, imax - imin + 1)

# filtering like this is also slower
# df_f_DA = df.filter(f_DA) # select only the DA photons, thus these are for KDE^X_DA
# df_f_DD = df.filter(f_DD) # KDE^X_DD (density of ch X at DD timestamps)

# kde_DA_DA = df_f_DA['DA'] # select DA density - kde^DA_DA
# kde_DA_DD = df_f_DD['DA'] # kde^DA_DD (in the paper called kde^A_D)
# kde_DD_DD = df_f_DD['DD'] # kde^DD_DD
# kde_DD_DA = df_f_DA['DD'] # kde^DD_DA

bools_DA = df["stream"] == "DA"
bools_DD = df["stream"] == "DD"

kde_DA_DA = df["DA"].filter(bools_DA) # select DA density - kde^DA_DA
kde_DA_DD = df["DA"].filter(bools_DD) # kde^DA_DD (in the paper called kde^A_D)
kde_DD_DD = df["DD"].filter(bools_DD) # kde^DD_DD
kde_DD_DA = df["DD"].filter(bools_DA) # kde^DD_DA

try:
nbkde_DA_DA = (1 + 2 / len(kde_DA_DA)) * (kde_DA_DA - 1)
nbkde_DD_DD = (1 + 2 / len(kde_DD_DD)) * (kde_DD_DD - 1)

# ED = (kde_DA_DD / (kde_DA_DD + nbkde_DD_DD)).sum() / nbkde_DD_DD.count()
# EA = (kde_DD_DA / (kde_DD_DA + nbkde_DA_DA)).sum() / nbkde_DA_DA.count() # = (1 - E)_A

# when denom is zero, it doesnt count towards number of photons
# see "Such cases are removed by the computer algorithm", in Tomov et al.
denom = kde_DA_DD + nbkde_DD_DD
ED = (kde_DA_DD / denom).sum() / (denom != 0.0).sum() # when

denom = kde_DD_DA + nbkde_DA_DA
EA = (kde_DD_DA / denom).sum() / (denom != 0.0).sum() # = (1 - E)_A

fret_cde = 110 - 100 * (ED + EA)
output[idx] = fret_cde
except ZeroDivisionError:
output[idx] = np.nan

return output


# refactor to indices / loop
def compute_alex_2cde(
burst_photons: pl.DataFrame,
kde_rates: pl.DataFrame,
dex_streams: Optional[list[str]] = None,
Expand All @@ -39,11 +89,14 @@ def alex_2cde(

# equivalent to (but faster):
# joined_df = burst_photons.join(kde_rates, on=['timestamps', 'stream'], how='inner')
j1 = burst_photons.join(kde_rates, on=["timestamps"], how="left")
joined_df = j1.filter(pl.col("stream") == pl.col("stream_right"))
# joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
# pl.col("stream") == pl.col("stream_right")
# )
# still, this is the slowest step. would be nice if we can improve since we know the indices

# remove comments after passing test:
# j1 = burst_photons.join(kde_rates, on=["timestamps"], how="left")
# joined_df = j1.filter(pl.col("stream") == pl.col("stream_right"))
joined_df = burst_photons.join(kde_rates, on=["timestamps"], how="left").filter(
pl.col("stream") == pl.col("stream_right")
)

b_df = joined_df.select(
[
Expand All @@ -67,9 +120,12 @@ def alex_2cde(
combined = pl.concat([agg_dex, agg_dax], how="align")

# tomov et al eqn 12
# this is an addition in the paper
# in fretbursts its subtracted
ax_2cde_bracket = (1 / pl.col("N_aex")) * pl.col("ratio_AD") - (1 / pl.col("N_dex")) * pl.col(
"ratio_DA"
)

ax_2cde_norm = pl.lit(100) - pl.lit(50) * ax_2cde_bracket

alex_2cde = combined.select(ax_2cde_norm.alias("alex_2cde")).to_series()
Expand Down Expand Up @@ -104,6 +160,9 @@ def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray)
event_times = df["timestamps"].to_numpy(allow_copy=True)
eval_times = data["timestamps"].to_numpy(allow_copy=True)

# event_times = np.array(df["timestamps"])
# eval_times = np.array(data["timestamps"])

return async_convolve(event_times, eval_times, kernel)


Expand All @@ -113,8 +172,7 @@ def convolve_stream(data: pl.DataFrame, streams: list[str], kernel: np.ndarray)
types.Array(int64, 1, "C", readonly=True),
types.Array(float64, 1, "C", readonly=True),
),
nopython=True,
cache=True,
nopython=False,
nogil=True,
)
def async_convolve(event_times, eval_times, kernel):
Expand All @@ -128,12 +186,16 @@ def async_convolve(event_times, eval_times, kernel):

for i, time in enumerate(eval_times):
while event_times[i_upper] < time + window_half_size:
i_upper += 1
if i_upper == len(event_times):
i_upper -= 1
break
i_upper += 1

while event_times[i_lower] < time - window_half_size:
i_lower += 1
if i_lower == len(event_times):
i_lower -= 1
break

for j in range(i_lower, i_upper):
idx = event_times[j] - time + window_half_size
Expand Down
12 changes: 10 additions & 2 deletions dont_fret/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,18 +250,21 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts:
# Check if any of the times _items is empty, if so, bursts is empty
if any(len(t) == 0 for t in times_list):
burst_photons = pl.DataFrame({k: [] for k in self.data.columns + ["burst_index"]})
indices = pl.DataFrame({"imin": [], "imax": []})
else:
# Take the intersection of the time intervals found by the multi-color burst search
final_times = reduce(return_intersections, times_list)

if len(final_times) == 0: # No overlap found
burst_photons = pl.DataFrame({k: [] for k in self.data.columns + ["burst_index"]})
indices = pl.DataFrame({"imin": [], "imax": []})
else:
tmin, tmax = np.array(final_times).T

# Convert back to indices
imin = np.searchsorted(self.timestamps, tmin)
imax = np.searchsorted(self.timestamps, tmax)
indices = pl.DataFrame({"imin": imin, "imax": imax})
# take all photons (up to and including? edges need to be checked!)
b_num = int(2 ** np.ceil(np.log2((np.log2(len(imin))))))
dtype = getattr(pl, f"UInt{b_num}", pl.Int32)
Expand All @@ -271,7 +274,7 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts:
]
burst_photons = pl.concat(bursts)

bs = Bursts(burst_photons, metadata=self.metadata)
bs = Bursts(burst_photons, indices=indices, metadata=self.metadata)

return bs

Expand Down Expand Up @@ -364,9 +367,14 @@ class Bursts(object):
# todo add metadata support
# bursts: numpy.typing.ArrayLike[Bursts] ?
def __init__(
self, photon_data: pl.DataFrame, metadata: Optional[dict] = None, cfg: DontFRETConfig = cfg
self,
photon_data: pl.DataFrame,
indices: pl.DataFrame,
metadata: Optional[dict] = None,
cfg: DontFRETConfig = cfg,
):
self.photon_data = photon_data
self.indices = indices
self.metadata: dict = metadata or {}
self.cfg = cfg

Expand Down
28 changes: 22 additions & 6 deletions templates/07_c_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
import matplotlib.pyplot as plt
import polars as pl

from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
from dont_fret.channel_kde import compute_alex_2cde, compute_fret_2cde, convolve_stream, make_kernel
from dont_fret.expr import is_in_expr
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData

# %%
cwd = Path(__file__).parent
test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1"
test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds3"
ptu_file = "datafile_1.ptu"
ptu_file = "210122_sFRET_MBP5K_apo_15.ptu"

# %%
photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file))
bursts = photons.burst_search("APBS")

tau = 50e-6
assert photons.timestamps_unit
Expand All @@ -24,20 +27,33 @@
ax.plot(kernel)

# %%

# tomov et al eqn 9
D_ex = convolve_stream(photons.data, ["DD", "DA"], kernel)
A_ex = convolve_stream(photons.data, ["AA"], kernel)
A_ex = convolve_stream(photons.data, ["AA"], kernel) # crashed the kernel (sometimes)

# %%

kde_data = photons.data.select(
[pl.col("timestamps"), pl.col("stream"), pl.lit(D_ex).alias("D_ex"), pl.lit(A_ex).alias("A_ex")]
)
alex_2cde = compute_alex_2cde(bursts.photon_data, kde_data)

bursts = photons.burst_search("APBS")
DA = convolve_stream(photons.data, ["DA"], kernel)
DD = convolve_stream(photons.data, ["DD"], kernel)
kde_data = photons.data.select(
[pl.col("timestamps"), pl.col("stream"), pl.lit(DA).alias("DA"), pl.lit(DD).alias("DD")]
)

# %%

fret_2cde = compute_fret_2cde(bursts.indices, kde_data)

fret_2cde

# %%
alex_2cde_vals = alex_2cde(bursts.photon_data, kde_data)

fig, axes = plt.subplots(ncols=2)
axes[0].hist(alex_2cde, bins="fd")
axes[1].hist(fret_2cde, bins="fd")
# axes[1].axvline(10, color='r')
# %%
4 changes: 2 additions & 2 deletions tests/test_c_kde.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import polars as pl
import polars.testing as pl_test

from dont_fret.channel_kde import alex_2cde, convolve_stream, make_kernel
from dont_fret.channel_kde import compute_alex_2cde, convolve_stream, make_kernel
from dont_fret.fileIO import PhotonFile
from dont_fret.models import PhotonData

Expand Down Expand Up @@ -44,5 +44,5 @@ def test_compare_cde_fretbursts():

burst_data = pl.concat(burst_dfs, how="vertical_relaxed")

alex_2cde_vals = alex_2cde(burst_data, kde_data)
alex_2cde_vals = compute_alex_2cde(burst_data, kde_data)
pl_test.assert_series_equal(alex_2cde_vals.fill_null(np.nan), bursts_ref["alex_2cde"])

0 comments on commit 7f21537

Please sign in to comment.