Skip to content

Commit

Permalink
Improve stability of compressed spectrogram generation
Browse files (browse the repository at this point in the history)
  • Loading branch information
bluemellophone committed Jan 19, 2024
1 parent d809c13 commit 77f2306
Showing 1 changed file with 94 additions and 74 deletions.
168 changes: 94 additions & 74 deletions bats_ai/core/models/spectrogram.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,87 +161,107 @@ def generate(cls, recording):
def compressed(self):
img = self.image_np

canvas = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
canvas = canvas.astype(np.float32)

amplitude = canvas.max(axis=0)
amplitude -= amplitude.min()
amplitude /= amplitude.max()
amplitude[amplitude < 0.5] = 0.0
amplitude[amplitude > 0] = 1.0
amplitude = amplitude.reshape(1, -1)

canvas -= canvas.min()
canvas /= canvas.max()
canvas *= 255.0
canvas *= amplitude
canvas = np.around(canvas).astype(np.uint8)

mask = canvas.max(axis=0)
mask = scipy.signal.medfilt(mask, 3)
mask[0] = 0
mask[-1] = 0
starts = []
stops = []
for index in range(1, len(mask) - 1):
value_pre = mask[index - 1]
value = mask[index]
value_post = mask[index + 1]
if value != 0:
if value_pre == 0:
starts.append(index)
if value_post == 0:
stops.append(index)
assert len(starts) == len(stops)

starts = [val - 40 for val in starts] # 10 ms buffer
stops = [val + 40 for val in stops] # 10 ms buffer
ranges = list(zip(starts, stops))

threshold = 0.5
while True:
found = False
merged = []
index = 0
while index < len(ranges) - 1:
start1, stop1 = ranges[index]
start2, stop2 = ranges[index + 1]
if stop1 >= start2:
found = True
merged.append((start1, stop2))
index += 2
else:
merged.append((start1, stop1))
index += 1
canvas = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
canvas = canvas.astype(np.float32)

amplitude = canvas.max(axis=0)
amplitude -= amplitude.min()
amplitude /= amplitude.max()
amplitude[amplitude < threshold] = 0.0
amplitude[amplitude > 0] = 1.0
amplitude = amplitude.reshape(1, -1)

canvas -= canvas.min()
canvas /= canvas.max()
canvas *= 255.0
canvas *= amplitude
canvas = np.around(canvas).astype(np.uint8)

mask = canvas.max(axis=0)
mask = scipy.signal.medfilt(mask, 3)
mask[0] = 0
mask[-1] = 0
starts = []
stops = []
for index in range(1, len(mask) - 1):
value_pre = mask[index - 1]
value = mask[index]
value_post = mask[index + 1]
if value != 0:
if value_pre == 0:
starts.append(index)
if value_post == 0:
stops.append(index)
assert len(starts) == len(stops)

starts = [val - 40 for val in starts] # 10 ms buffer
stops = [val + 40 for val in stops] # 10 ms buffer
ranges = list(zip(starts, stops))

while True:
found = False
merged = []
index = 0
while index < len(ranges) - 1:
start1, stop1 = ranges[index]
start2, stop2 = ranges[index + 1]

start1 = min(max(start1, 0), len(mask))
start2 = min(max(start2, 0), len(mask))
stop1 = min(max(stop1, 0), len(mask))
stop2 = min(max(stop2, 0), len(mask))

if stop1 >= start2:
found = True
merged.append((start1, stop2))
index += 2
else:
merged.append((start1, stop1))
index += 1
if index == len(ranges) - 1:
merged.append((start2, stop2))
ranges = merged
if not found:
for index in range(1, len(ranges)):
start1, stop1 = ranges[index - 1]
start2, stop2 = ranges[index]
assert start1 < stop1
assert start2 < stop2
assert start1 < start2
assert stop1 < stop2
assert stop1 < start2
ranges = merged
if not found:
for index in range(1, len(ranges)):
start1, stop1 = ranges[index - 1]
start2, stop2 = ranges[index]
assert start1 < stop1
assert start2 < stop2
assert start1 < start2
assert stop1 < stop2
assert stop1 < start2
break

segments = []
starts_ = []
stops_ = []
domain = img.shape[1]
for start, stop in ranges:
segment = img[:, start:stop]
segments.append(segment)

starts_.append(int(round(self.duration * (start / domain))))
stops_.append(int(round(self.duration * (stop / domain))))

# buffer = np.zeros((len(img), 20, 3), dtype=img.dtype)
# segments.append(buffer)
# segments = segments[:-1]

if len(segments) > 0:
break

segments = []
starts_ = []
stops_ = []
domain = img.shape[1]
for start, stop in ranges:
segment = img[:, start:stop]
segments.append(segment)

starts_.append(int(round(self.duration * (start / domain))))
stops_.append(int(round(self.duration * (stop / domain))))
threshold -= 0.05
if threshold < 0:
segments = None
break

# buffer = np.zeros((len(img), 20, 3), dtype=img.dtype)
# segments.append(buffer)
# segments = segments[:-1]
if segments is None:
canvas = img.copy()
else:
canvas = np.hstack(segments)

canvas = np.hstack(segments)
canvas = Image.fromarray(canvas, 'RGB')

# canvas.save('temp.jpg')
Expand Down

0 comments on commit 77f2306

Please sign in to comment.