-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement alternative bout detection method #15
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,168 @@ def extract_segments_above_threshold( | |
|
||
return np.array(segments), np.array(connected) | ||
|
||
@jit(nopython=True) | ||
def get_score_trace(trace_pad, kernel, bias=-0.2839): | ||
""" | ||
Used to calculate the correlation score between the squared velocity | ||
trace and the kernel for bout detection. | ||
:param trace_pad: squared velocity (optionally) padded | ||
:param kernel: kernel used for bout detection | ||
:return: | ||
array of correlation scores that matches with the trace by index. | ||
""" | ||
|
||
# For numerical stability. | ||
eta = 0.0000000001 | ||
|
||
corr = np.empty((len(trace_pad) - len(kernel))) | ||
for i in range(len(trace_pad) - len(kernel)): | ||
current_values = trace_pad[i:i+len(kernel)] | ||
relative_values = current_values / (np.max(current_values) + eta) | ||
|
||
# Simplified convolution, works faster and with Numba. | ||
conv = np.sum(relative_values * np.flip(kernel)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only issue with using convolutions is one needs to be aware that, depending on the threshold used, bouts may end up being detected earlier than they occurred. If possible, I would try to make clear somewhere in the docs that one needs to keep this in mind when processing and analyzing the data. |
||
|
||
corr[i] = 1 / (1 + np.exp(conv + bias)) | ||
|
||
# Align the correlation score trace with the actual bouts. | ||
# Found by trial and error - so can be changed as needed. | ||
corr = corr[:-int(kernel.shape[0]/5)] | ||
corr = np.concatenate((np.zeros(int(kernel.shape[0]/5)), corr)) | ||
|
||
return (corr - 1) * -1 | ||
|
||
|
||
def calc_bout_score(trace, kernel=None, bias=-0.2839, pad_len=None): | ||
""" | ||
Wrapper for the get_score_trace function. | ||
:param trace: the trace to detect bouts in. | ||
:param kernel: kernel used for bout detection. | ||
:param bias: the bias to shift the output | ||
:param pad_len: the length of the padding on each side. | ||
:return: | ||
array of correlation scores that matches with the trace by index. | ||
""" | ||
|
||
if kernel is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An issue I see with this approach is that the kernel is probably quite dependent on the rate at what your data has been sampled. One could up/downsample this kernel, but then it should at least be indicated the sampling rate of the data it is based on, and instead of hardcoded in the function, maybe it could be stored in an assets/kernels/whatever folder from where one can load it and modify it before using it, rather than just having to alter it within the function. |
||
# Kernel trained using a NN. | ||
kernel = np.array([ | ||
-0.5633, -0.5888, -0.5097, -0.3899, -0.4896, -0.4326, -0.5017, | ||
-0.4555, -0.4709, -0.4580, -0.4563, -0.4510, -0.3165, -0.3896, | ||
-0.2522, -0.2982, -0.2503, -0.0670, 0.1450, 0.2824, 0.2420, | ||
0.3270, 0.2569, 0.3174, 0.3718, 0.2896, 0.3356, 0.3905, | ||
0.2844, 0.3674, 0.3158, 0.3549, 0.2847, 0.4782, 0.4236, | ||
0.4416, 0.4049, 0.3622, 0.3755, 0.2569, 0.2525, 0.2626, | ||
0.2875, 0.1809, 0.1314, 0.1213, 0.1023, -0.0113, 0.1013, | ||
0.0844, -0.0785, -0.0316, -0.0584, -0.1613, -0.1835, -0.1843, | ||
-0.1421, -0.1407, -0.0838, -0.1655, -0.2004, -0.0968, -0.1559, | ||
-0.1564, -0.1867, -0.1494, -0.1192, -0.2535, -0.1645, -0.1529, | ||
-0.1918, -0.1987, -0.2686, -0.2107, -0.2132, -0.2063, -0.2253, | ||
-0.1670, -0.2638, -0.2669, -0.1228, -0.1679, -0.2795, -0.2066, | ||
-0.1625, -0.1498, -0.1983, -0.2351, -0.2337, -0.2803, -0.3189, | ||
-0.2672, -0.2565, -0.3481, -0.3722, -0.3055, -0.3174, -0.4059, | ||
-0.3363, -0.4085]) | ||
|
||
if pad_len is None: | ||
pad_len = int(kernel.shape[0]/2) | ||
|
||
trace_pad = np.pad(trace, pad_len) | ||
return get_score_trace(trace_pad, kernel) | ||
|
||
|
||
def get_bout_times(trace, | ||
min_peak_value=0.75, | ||
max_baseline_value=0.05, | ||
include_nan=False, | ||
max_zero_length=(55, 55), | ||
min_bout_distance=0, | ||
**kwargs): | ||
"""Finds bout peaks and their start and end from a convolved | ||
trace (only tested on freely-swimming experiments). | ||
:param trace: the squared velocity trace convolved with a bout detection kernel. | ||
:param min_peak_value: the minimum peak value for a bump to count as a bout. | ||
In NN terms this is the value of the sigmoid for classification, so 0.5 | ||
Would be the cut-off value. However, one can opt for a higher or lower value | ||
if they want to change the false positive or true negative rate (i.e. a higher | ||
value should classify less noise as bouts but also means that more bouts will | ||
not be detected). | ||
:param max_baseline_value: (ab)uses the fact that a convolution will result in a smooth | ||
signal. The sides of the peak more or less reflect the start and end times of the | ||
bout. This is the cut-off value that determines at which points the peak start/ends | ||
and thus the bout starts/ends. | ||
:param include_nan: include bouts containing NaN values. | ||
:param max_length_to_baseline: the maximum number of samples that it may take to reach the | ||
baseline (defined by max_baseline_value) from the peak. | ||
:param min_bout_distance: minimum distance between the end of a bout and the peak of the | ||
next bout. | ||
:return: tuple: (bout start times, bout peak times, bout end times) | ||
""" | ||
|
||
last_bout_end = 0 | ||
bouts = [] | ||
|
||
i = 0 | ||
while i < trace.shape[0]: | ||
# Check if the value is high enough to indicate a bout. | ||
if trace[i] > min_peak_value: | ||
# We found a bout! | ||
bout_start = i | ||
bout_end = i | ||
|
||
# Check how far back it goes. | ||
j = i + 1 | ||
found_start = False | ||
while j > last_bout_end: | ||
if trace[j] < max_baseline_value: | ||
bout_start = j | ||
found_start = True | ||
break | ||
|
||
j -= 1 | ||
|
||
if not found_start: | ||
bout_start = last_bout_end | ||
|
||
# Check how far the end is. | ||
j = i + 1 | ||
found_end = False | ||
while j < trace.shape[0]: | ||
if trace[j] < max_baseline_value: | ||
bout_end = j | ||
found_end = True | ||
break | ||
elif trace[j-1] < min_peak_value < trace[j]: | ||
# We found the next bout! | ||
# We stop this bout at the valley. | ||
bout_end = i + np.argmin(trace[i:j]) | ||
found_end = True | ||
break | ||
|
||
j += 1 | ||
|
||
if not found_end: | ||
bout_end = trace.shape[0] - 1 | ||
|
||
# Update values. | ||
last_bout_end = bout_end | ||
# Filter out bouts containing NaNs. | ||
if include_nan or not np.isnan(trace[bout_start:bout_end]).any(): | ||
# The convolution during noise is less steep, | ||
# it might not reach the baseline value within the typical | ||
# bout time window. So we can use the requirement of the 'bout' | ||
# stopping within a certain time range, as a way to filter noise. | ||
if max_zero_length is None \ | ||
or (max_zero_length[0] > i - bout_start \ | ||
or max_zero_length[1] > bout_end - i): | ||
bouts.append((bout_start, i, bout_end)) | ||
|
||
# Skip the remaining part of the bout and optionally some extra distance. | ||
i = bout_end + min_bout_distance | ||
|
||
i += 1 | ||
|
||
return bouts | ||
|
||
|
||
def log_dt(log_df, i_start=10, i_end=110): | ||
return np.mean(np.diff(log_df.t[i_start:i_end])) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if the fact that this threshold is provided via the same argument can lead to small issues with people not realizing the thresholds for both methods are in different units, and therefore not detecting bouts properly.
One could split this into two different threshold arguments, but maybe it makes the function uglier with lots of arguments. Alternatively, maybe adding a warning if your method is used, pointing towards thinking about the provided threshold?