Implement alternative bout detection method #15

Open · wants to merge 1 commit into master
28 changes: 20 additions & 8 deletions bouter/free/__init__.py
@@ -136,12 +136,13 @@ def compute_velocity(
         return fish_velocities

     @decorators.cache_results()
-    def get_bouts(self, scale=None, threshold=1, **kwargs):
+    def get_bouts(self, scale=None, threshold=1, conv_detection=False, **kwargs):
         """Extracts all bouts from a freely-swimming tracking experiment

         :param exp: the experiment object
         :param scale: mm per pixel, recalculated by default
-        :param threshold: velocity threshold in mm/s
+        :param threshold: velocity threshold in mm/s, or score threshold if conv_detection=True
Contributor:
I wonder whether providing both thresholds through the same argument could lead to subtle issues: users may not realize that the two methods expect thresholds in different units, and could end up not detecting bouts properly. One could split this into two separate threshold arguments, but that might make the signature unwieldy with too many arguments. Alternatively, a warning when the new method is used could prompt users to reconsider the provided threshold.
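A minimal sketch of such a warning, assuming the shared argument is kept (the helper name and message wording are illustrative, not part of the PR):

import warnings

def _warn_threshold_units(conv_detection):
    # Hypothetical helper, called at the top of get_bouts: remind users that
    # conv_detection changes the units of `threshold` from mm/s to a 0-1 score.
    if conv_detection:
        warnings.warn(
            "conv_detection=True interprets `threshold` as a detection score "
            "in the 0-1 range, not a velocity in mm/s.",
            UserWarning,
        )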

+        :param conv_detection: whether to use an alternative detection algorithm based on convolution.
         :return: tuple: (list of single bout dataframes, list of boolean arrays marking if
             bout i follows bout i-1)
         """
@@ -156,12 +157,23 @@ def get_bouts(self, scale=None, threshold=1, **kwargs):

         for i_fish in range(n_fish):
             vel2 = fish_velocities["vel_f{}".format(i_fish)]
-            (
-                bout_locations,
-                continuity,
-            ) = utilities.extract_segments_above_threshold(
-                vel2.values, threshold=threshold**2, **kwargs
-            )
+            if not conv_detection:
+                (
+                    bout_locations,
+                    continuity,
+                ) = utilities.extract_segments_above_threshold(
+                    vel2.values, threshold=threshold**2, **kwargs
+                )
+            else:
+                score = utilities.calc_bout_score(vel2.values)
+                bout_times = utilities.get_bout_times(
+                    score, min_peak_value=threshold, **kwargs
+                )
+
+                # For compatibility with the segment-based output: keep only
+                # the (start, end) columns and mark all bouts as non-consecutive.
+                bout_locations = np.array(bout_times)
+                bout_locations = bout_locations[:, [0, 2]]
+                continuity = [False] * bout_locations.shape[0]
             all_bouts_fish = [
                 self._extract_bout(s, e, n_segments, i_fish, scale)
                 for s, e in bout_locations
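For reference, a hedged usage sketch of the new flag (assuming the class exposing get_bouts is bouter's FreelySwimmingExperiment; the path and threshold value are illustrative):

from bouter.free import FreelySwimmingExperiment

exp = FreelySwimmingExperiment("/path/to/experiment")  # illustrative path
# With conv_detection=True, threshold is a 0-1 score, not a velocity in mm/s.
bouts, continuity = exp.get_bouts(conv_detection=True, threshold=0.75)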
162 changes: 162 additions & 0 deletions bouter/utilities.py
@@ -93,6 +93,168 @@ def extract_segments_above_threshold(

     return np.array(segments), np.array(connected)

+@jit(nopython=True)
+def get_score_trace(trace_pad, kernel, bias=-0.2839):
+    """
+    Calculates the correlation score between the squared velocity
+    trace and the kernel for bout detection.
+    :param trace_pad: squared velocity trace, optionally padded
+    :param kernel: kernel used for bout detection
+    :param bias: constant added to the convolution result before the sigmoid
+    :return:
+        array of correlation scores that matches the trace by index.
+    """
+
+    # Small constant for numerical stability.
+    eta = 1e-10
+
+    corr = np.empty(len(trace_pad) - len(kernel))
+    for i in range(len(trace_pad) - len(kernel)):
+        current_values = trace_pad[i:i + len(kernel)]
+        relative_values = current_values / (np.max(current_values) + eta)
+
+        # Simplified convolution; faster and compatible with Numba.
+        conv = np.sum(relative_values * np.flip(kernel))
Contributor:

The only issue with using convolutions is that, depending on the threshold used, bouts may be detected earlier than they actually occurred. If possible, I would make this clear somewhere in the docs, so that users keep it in mind when processing and analyzing the data.


+        corr[i] = 1 / (1 + np.exp(conv + bias))
+
+    # Align the correlation score trace with the actual bouts.
+    # The shift was found by trial and error, so adjust it as needed.
+    corr = corr[:-int(kernel.shape[0] / 5)]
+    corr = np.concatenate((np.zeros(int(kernel.shape[0] / 5)), corr))
+
+    return (corr - 1) * -1
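To act on the inline comment above, one could quantify the onset shift on one's own data by comparing both detectors. A rough sketch (vel2 is an assumed squared-velocity array already in memory; the thresholds are illustrative):

import numpy as np
from bouter import utilities

score = utilities.calc_bout_score(vel2)
conv_starts = np.array([start for start, _, _ in utilities.get_bout_times(score)])
segments, _ = utilities.extract_segments_above_threshold(vel2, threshold=1.0)

# For each convolution-based start, find the nearest threshold-based start;
# a positive median lag means the convolution detects bouts earlier.
lags = [segments[np.argmin(np.abs(segments[:, 0] - s)), 0] - s for s in conv_starts]
print("median onset shift (samples):", np.median(lags))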


+def calc_bout_score(trace, kernel=None, bias=-0.2839, pad_len=None):
+    """
+    Wrapper for the get_score_trace function.
+    :param trace: the trace to detect bouts in.
+    :param kernel: kernel used for bout detection.
+    :param bias: the bias that shifts the output.
+    :param pad_len: the length of the padding on each side.
+    :return:
+        array of correlation scores that matches the trace by index.
+    """
+
+    if kernel is None:
Contributor:

An issue I see with this approach is that the kernel probably depends quite strongly on the rate at which the data was sampled. One could up- or downsample the kernel, but then the sampling rate of the data it was trained on should at least be documented. Rather than hardcoding it in the function, the kernel could be stored in an assets/kernels folder, from where one can load it and modify it before use.
Maybe (big maybe) also add a function or instructions for generating one's own kernel?
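One possible shape for that suggestion; everything below is hypothetical (the assets path, file name, and sampling rates are assumptions, not part of the PR):

from pathlib import Path

import numpy as np

KERNEL_DIR = Path(__file__).parent / "assets" / "kernels"  # hypothetical location

def load_kernel(name="bout_kernel", fs_data=300.0, fs_trained=300.0):
    """Load a detection kernel and linearly resample it from the sampling
    rate it was trained at (fs_trained) to the rate of the data (fs_data)."""
    kernel = np.load(KERNEL_DIR / (name + ".npy"))
    if fs_data != fs_trained:
        n_new = int(round(len(kernel) * fs_data / fs_trained))
        x_old = np.linspace(0.0, 1.0, len(kernel))
        x_new = np.linspace(0.0, 1.0, n_new)
        kernel = np.interp(x_new, x_old, kernel)
    return kernel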

+        # Kernel trained using a NN.
+        kernel = np.array([
+            -0.5633, -0.5888, -0.5097, -0.3899, -0.4896, -0.4326, -0.5017,
+            -0.4555, -0.4709, -0.4580, -0.4563, -0.4510, -0.3165, -0.3896,
+            -0.2522, -0.2982, -0.2503, -0.0670, 0.1450, 0.2824, 0.2420,
+            0.3270, 0.2569, 0.3174, 0.3718, 0.2896, 0.3356, 0.3905,
+            0.2844, 0.3674, 0.3158, 0.3549, 0.2847, 0.4782, 0.4236,
+            0.4416, 0.4049, 0.3622, 0.3755, 0.2569, 0.2525, 0.2626,
+            0.2875, 0.1809, 0.1314, 0.1213, 0.1023, -0.0113, 0.1013,
+            0.0844, -0.0785, -0.0316, -0.0584, -0.1613, -0.1835, -0.1843,
+            -0.1421, -0.1407, -0.0838, -0.1655, -0.2004, -0.0968, -0.1559,
+            -0.1564, -0.1867, -0.1494, -0.1192, -0.2535, -0.1645, -0.1529,
+            -0.1918, -0.1987, -0.2686, -0.2107, -0.2132, -0.2063, -0.2253,
+            -0.1670, -0.2638, -0.2669, -0.1228, -0.1679, -0.2795, -0.2066,
+            -0.1625, -0.1498, -0.1983, -0.2351, -0.2337, -0.2803, -0.3189,
+            -0.2672, -0.2565, -0.3481, -0.3722, -0.3055, -0.3174, -0.4059,
+            -0.3363, -0.4085])
+
+    if pad_len is None:
+        pad_len = int(kernel.shape[0] / 2)
+
+    trace_pad = np.pad(trace, pad_len)
+    # Pass the bias through to the scoring function.
+    return get_score_trace(trace_pad, kernel, bias)


+def get_bout_times(trace,
+                   min_peak_value=0.75,
+                   max_baseline_value=0.05,
+                   include_nan=False,
+                   max_zero_length=(55, 55),
+                   min_bout_distance=0,
+                   **kwargs):
+    """Finds bout peaks and their starts and ends in a convolved
+    trace (only tested on freely-swimming experiments).
+    :param trace: the squared velocity trace convolved with a bout detection kernel.
+    :param min_peak_value: the minimum peak value for a bump to count as a bout.
+        In NN terms this is the sigmoid value used for classification, so 0.5
+        would be the cut-off. A higher or lower value trades off the false
+        positive and false negative rates (i.e. a higher value classifies less
+        noise as bouts, but also misses more real bouts).
+    :param max_baseline_value: (ab)uses the fact that a convolution yields a
+        smooth signal: the sides of a peak roughly mark the start and end of the
+        bout. This cut-off value determines at which points the peak, and thus
+        the bout, starts and ends.
+    :param include_nan: include bouts containing NaN values.
+    :param max_zero_length: tuple with the maximum number of samples the trace
+        may take to reach the baseline (defined by max_baseline_value) before
+        and after the detection sample; slower 'bouts' are discarded as noise.
+    :param min_bout_distance: minimum distance between the end of a bout and
+        the peak of the next bout.
+    :return: list of (bout start, bout peak, bout end) index tuples.
+    """

+    last_bout_end = 0
+    bouts = []
+
+    i = 0
+    while i < trace.shape[0]:
+        # Check if the value is high enough to indicate a bout.
+        if trace[i] > min_peak_value:
+            # We found a bout!
+            bout_start = i
+            bout_end = i
+
+            # Check how far back the bout goes. Starting at i rather than
+            # i + 1 avoids indexing past the end of the trace at the last sample.
+            j = i
+            found_start = False
+            while j > last_bout_end:
+                if trace[j] < max_baseline_value:
+                    bout_start = j
+                    found_start = True
+                    break
+
+                j -= 1
+
+            if not found_start:
+                bout_start = last_bout_end
+
+            # Check how far away the end is.
+            j = i + 1
+            found_end = False
+            while j < trace.shape[0]:
+                if trace[j] < max_baseline_value:
+                    bout_end = j
+                    found_end = True
+                    break
+                elif trace[j - 1] < min_peak_value < trace[j]:
+                    # We found the next bout!
+                    # Stop this bout at the valley between the two.
+                    bout_end = i + np.argmin(trace[i:j])
+                    found_end = True
+                    break
+
+                j += 1
+
+            if not found_end:
+                bout_end = trace.shape[0] - 1
+
+            # Update values.
+            last_bout_end = bout_end
+            # Filter out bouts containing NaNs.
+            if include_nan or not np.isnan(trace[bout_start:bout_end]).any():
+                # The convolution of noise is less steep, so it might not reach
+                # the baseline value within the typical bout time window. Requiring
+                # the 'bout' to reach baseline within a certain range therefore
+                # filters out noise.
+                if max_zero_length is None or (
+                    max_zero_length[0] > i - bout_start
+                    or max_zero_length[1] > bout_end - i
+                ):
+                    bouts.append((bout_start, i, bout_end))
+
+            # Skip the rest of the bout, plus optionally some extra distance.
+            i = bout_end + min_bout_distance
+
+        i += 1
+
+    return bouts
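A quick hedged check of the detector on a synthetic score trace (the bump shape and parameter values are illustrative):

import numpy as np
from bouter.utilities import get_bout_times

score = np.zeros(1000)
score[100:160] = np.hanning(60)  # one smooth, bout-like bump peaking near 1.0

bouts = get_bout_times(score, min_peak_value=0.75, max_baseline_value=0.05)
print(bouts)  # [(104, 120, 155)]: (start, detection crossing, end) indices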


 def log_dt(log_df, i_start=10, i_end=110):
     return np.mean(np.diff(log_df.t[i_start:i_end]))