Skip to content

Commit

Permalink
implement classify stage w/ MaaS approach
Browse files Browse the repository at this point in the history
  • Loading branch information
pmhalvor committed Sep 28, 2024
1 parent 2d4b087 commit eba95dd
Show file tree
Hide file tree
Showing 5 changed files with 321 additions and 131 deletions.
194 changes: 194 additions & 0 deletions examples/classify_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
"""
Beam PTransform for classifying whale audio.
Gets stuck on classification, either due to memory issues or model serialization.
Kept for reference, but replaced by InferenceClient in classify.py.
"""
from apache_beam.io import filesystems
from datetime import datetime

import apache_beam as beam
import io
import librosa
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy
import time
import tensorflow_hub as hub
import tensorflow as tf

from config import load_pipeline_config


# Pipeline configuration, loaded once at import time; the classifier classes
# below read their sample rates, model path and plot path template from it.
config = load_pipeline_config()


class BaseClassifier(beam.PTransform):
    """Common preprocessing helpers for whale-audio classifier transforms.

    Subclasses are expected to provide ``expand`` and a classification step;
    this base class only covers resampling, batching, key building and model
    loading.
    """
    name = "BaseClassifier"

    # Batch length in seconds used when the pipeline config does not define
    # ``classify.batch_duration``.
    # NOTE(review): both the config key and this fallback value are
    # assumptions -- confirm against config.yaml.
    default_batch_duration = 30

    def __init__(self):
        """Read sample rates, batch duration and model path from the config."""
        super().__init__()
        self.source_sample_rate = config.audio.source_sample_rate
        self.model_sample_rate = config.classify.model_sample_rate
        self.batch_duration = getattr(
            config.classify, "batch_duration", self.default_batch_duration
        )

        self.model_path = config.classify.model_path

    def _preprocess(self, pcoll):
        """Resample one element and yield it as ``(key, batch)`` pairs.

        This is a generator, so it must be applied with ``beam.FlatMap``
        (``beam.Map`` would emit the generator object itself).

        Args:
            pcoll: A single element ``(signal, start, end, encounter_ids)``.

        Yields:
            ``(key, batch)`` tuples. A signal longer than
            ``batch_duration * model_sample_rate`` samples is split into
            consecutive batches of at most that many samples; a shorter one
            is yielded whole.

        Raises:
            ValueError: If the configured batch size is not positive.
        """
        signal, start, end, encounter_ids = pcoll
        key = self._build_key(start, end, encounter_ids)

        # Resample
        signal = self._resample(signal)

        # The signal is at the model rate after resampling, so the batch size
        # is expressed in model-rate samples. Cast to int because
        # np.array_split needs integer split points.
        batch_samples = int(self.batch_duration * self.model_sample_rate)
        if batch_samples <= 0:
            raise ValueError(
                f"Batch size must be positive, got {batch_samples} samples."
            )

        total_samples = signal.shape[0]
        if total_samples <= batch_samples:
            yield (key, signal)
            return

        logging.debug(f"Signal size exceeds max sample size {batch_samples}.")

        # Stop before total_samples so an exact multiple does not produce an
        # empty trailing batch.
        split_indices = list(range(batch_samples, total_samples, batch_samples))
        signal_batches = np.array_split(signal, split_indices)
        logging.debug(
            f"Split signal into {len(signal_batches)} batches of size {batch_samples}.")
        logging.debug(f"Size of final batch {len(signal_batches[-1])}")

        for batch in signal_batches:
            yield (key, batch)

    def _build_key(
        self,
        start_time: datetime,
        end_time: datetime,
        encounter_ids: list,
    ):
        """Build a string key from the time window and encounter ids.

        Args:
            start_time: Start of the window; formatted as ``YYYYmmddTHHMMSS``.
            end_time: End of the window; formatted as ``HHMMSS`` only.
            encounter_ids: Ids joined with underscores. Each id is converted
                with ``str`` so non-string ids do not break the join.

        Returns:
            ``"{start}-{end}_{ids}"``.
        """
        start_str = start_time.strftime('%Y%m%dT%H%M%S')
        end_str = end_time.strftime('%H%M%S')
        encounter_str = "_".join(str(encounter_id) for encounter_id in encounter_ids)
        return f"{start_str}-{end_str}_{encounter_str}"

    def _postprocess(self, pcoll):
        """Return the element unchanged (hook for subclasses)."""
        return pcoll

    def _get_model(self):
        """Load and return the model found at ``self.model_path`` via TF Hub."""
        model = hub.load(self.model_path)
        return model

    def _resample(self, signal):
        """Resample ``signal`` from the source rate to the model rate."""
        logging.info(
            f"Resampling signal from {self.source_sample_rate} to {self.model_sample_rate}")
        return librosa.resample(
            signal,
            orig_sr=self.source_sample_rate,
            target_sr=self.model_sample_rate
        )


class GoogleHumpbackWhaleClassifier(BaseClassifier):
    """
    Model docs: https://tfhub.dev/google/humpback_whale/1

    Pipeline: preprocess (resample + batch) -> classify -> postprocess.
    Emits ``(key, score_values)`` per batch.
    """
    name = "GoogleHumpbackWhaleClassifier"

    # Attributes holding the loaded model. They are dropped when the transform
    # is pickled and reloaded lazily in _classify, so the model is not
    # serialized along with the bound methods handed to Beam.
    _model_attrs = ("model", "score_fn", "metadata_fn")

    # Offset subtracted from the spectrogram (in dB) when the config does not
    # define ``classify.hydrophone_sensitivity``.
    # NOTE(review): both the config key and this fallback value are
    # assumptions -- confirm for the hydrophone actually used.
    default_hydrophone_sensitivity = -168.8

    def __init__(self, *args, **kwargs):
        """Initialise the base transform and load the model signatures."""
        super().__init__(*args, **kwargs)
        self.hydrophone_sensitivity = getattr(
            config.classify,
            "hydrophone_sensitivity",
            self.default_hydrophone_sensitivity,
        )
        self._load_model()

    def __getstate__(self):
        """Return the instance state without the loaded model objects."""
        state = self.__dict__.copy()
        for attr in self._model_attrs:
            state.pop(attr, None)
        return state

    def _load_model(self):
        """Load the model and cache its 'score' and 'metadata' signatures."""
        self.model = self._get_model()
        self.score_fn = self.model.signatures['score']
        self.metadata_fn = self.model.signatures['metadata']

    def expand(self, pcoll):
        """Apply preprocessing, classification and postprocessing to ``pcoll``."""
        return (
            pcoll
            # _preprocess is a generator yielding one (key, batch) per batch,
            # so it has to be flattened; Map would emit the generator object.
            | "Preprocess" >> beam.FlatMap(self._preprocess)
            | "Classify" >> beam.Map(self._classify)
            | "Postprocess" >> beam.Map(self._postprocess)
        )

    def _classify(self, pcoll, ):
        """Score one ``(key, signal)`` element with the model.

        Args:
            pcoll: A single ``(key, signal)`` element as produced by
                ``_preprocess``.

        Returns:
            ``(key, score_values)`` where ``score_values`` is the first
            (only) batch entry of the model's ``'scores'`` output.
        """
        key, signal = pcoll

        # The model is stripped from the pickled transform; reload on first use.
        if getattr(self, "score_fn", None) is None:
            self._load_model()

        start_classify = time.time()

        # We specify a 1-sec score resolution:
        context_step_samples = tf.cast(self.model_sample_rate, tf.int64)

        logging.info(f'\n==> Applying model ...')
        logging.debug(f' inital input: len(signal_10kHz) = {len(signal)}')

        waveform1 = np.expand_dims(signal, axis=1)
        waveform_exp = tf.expand_dims(waveform1, 0)  # makes a batch of size 1
        logging.debug(f" final input: waveform_exp.shape = {waveform_exp.shape}")

        signal_scores = self.score_fn(
            waveform=waveform_exp,
            context_step_samples=context_step_samples
        )
        score_values = signal_scores['scores'].numpy()[0]
        logging.info(f'==> Model applied. Obtained {len(score_values)} score_values')
        logging.info(f'==> Elapsed time: {time.time() - start_classify} seconds')

        return (key, score_values)

    def _plot_spectrogram_scipy(self, signal, epsilon = 1e-15):
        """Draw a spectrogram of ``signal`` on a new matplotlib figure.

        Args:
            signal: Audio samples; the spectrogram is computed at
                ``self.model_sample_rate``, so this should be the resampled
                signal.
            epsilon: Small constant added to the PSD before taking the log
                to avoid ``log10(0)``.
        """
        # Make sure the submodule is loaded; ``import scipy`` alone does not
        # guarantee ``scipy.signal`` is available on every SciPy version.
        import scipy.signal

        # Compute spectrogram:
        w = scipy.signal.get_window('hann', self.model_sample_rate)
        f, t, psd = scipy.signal.spectrogram(
            signal,  # TODO make sure this is resampled signal
            self.model_sample_rate,
            nperseg=self.model_sample_rate,
            noverlap=0,
            window=w,
            nfft=self.model_sample_rate,
        )
        psd = 10*np.log10(psd+epsilon) - self.hydrophone_sensitivity

        # Plot spectrogram:
        plt.figure(figsize=(20, round(20/3)))  # 3:1 aspect ratio
        plt.imshow(
            psd,
            aspect='auto',
            origin='lower',
            vmin=30,
            vmax=90,
            cmap='Blues',
        )
        plt.yscale('log')
        y_max = self.model_sample_rate / 2
        plt.ylim(10, y_max)

        plt.colorbar()

        plt.xlabel('Seconds')
        plt.ylabel('Frequency (Hz)')
        plt.title(f'Calibrated spectrum levels, {self.model_sample_rate / 1000.0} kHz data')

    def _plot_scores(self, pcoll, scores, med_filt_size=None):
        """Plot model scores as a step chart and save it to the plot path.

        Args:
            pcoll: A single ``(audio, start, end, encounter_ids)`` element;
                only the time window and ids are used, to name the plot.
            scores: Model scores, one entry per step.
            med_filt_size: Optional kernel size for a median filter overlaid
                on the scores; skipped when ``None``.
        """
        _, start, end, encounter_ids = pcoll
        key = self._build_key(start, end, encounter_ids)

        # repeat last value to also see a step at the end:
        scores = np.concatenate((scores, scores[-1:]))
        x = range(len(scores))
        plt.step(x, scores, where='post')
        plt.plot(x, scores, 'o', color='lightgrey', markersize=9)

        plt.grid(axis='x', color='0.95')
        plt.xlim(xmin=0, xmax=len(scores) - 1)
        plt.ylabel('Model Score')
        plt.xlabel('Seconds')

        if med_filt_size is not None:
            import scipy.signal

            # Take the first value of each score so both (N,) and (N, 1)
            # inputs work, then filter on integer milli-scores as before.
            first_values = np.asarray(scores).reshape(len(scores), -1)[:, 0]
            scores_int = (first_values * 1000).astype(int)
            meds_int = scipy.signal.medfilt(scores_int, kernel_size=med_filt_size)
            meds = [m/1000. for m in meds_int]
            plt.plot(x, meds, 'p', color='black', markersize=9)

        plot_path = config.classify.plot_path_template.format(
            year=start.year,
            month=start.month,
            day=start.day,
            plot_name=key
        )
        # savefig does not create missing directories.
        plot_dir = os.path.dirname(plot_path)
        if plot_dir:
            os.makedirs(plot_dir, exist_ok=True)
        plt.savefig(plot_path)
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close()
22 changes: 12 additions & 10 deletions src/pipeline/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from stages.search import GeometrySearch
from stages.audio import RetrieveAudio, WriteAudio, WriteSiftedAudio
from stages.sift import Butterworth
from stages.classify import WhaleClassifier

from config import load_pipeline_config
config = load_pipeline_config()
Expand All @@ -18,27 +19,28 @@ def run():
}

with beam.Pipeline(options=pipeline_options) as p:
input_data = p | "Create Input" >> beam.Create([args])
search_results = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())
input_data = p | "Create Input" >> beam.Create([args])
search_output = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())

audio_results = search_results | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
audio_files = audio_results | "Store Audio (temp)" >> beam.ParDo(WriteAudio())
audio_output = search_output | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
audio_files = audio_output | "Store Audio (temp)" >> beam.ParDo(WriteAudio())

sifted_audio = audio_results | "Sift Audio" >> Butterworth()
sifted_audio = audio_output | "Sift Audio" >> Butterworth()
sifted_audio_files = sifted_audio | "Store Sifted Audio" >> beam.ParDo(WriteSiftedAudio("butterworth"))

# For debugging, you can write the output to a text file
# audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt')
# search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt')
classifications = sifted_audio | "Classify Audio" >> WhaleClassifier(config)


# classified_audio = filtered_audio | "Classify Audio" >> ClassifyAudio()

# # Post-process the labels
# postprocessed_labels = classified_audio | "Postprocess Labels" >> PostprocessLabels()

# Output results
# postprocessed_labels | "Write Results" >> beam.io.WriteToText("output.txt")

# For debugging, you can write the output to a text file
# audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt')
# search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt')


if __name__ == "__main__":
run()
2 changes: 2 additions & 0 deletions src/pipeline/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pipeline:
output_path_template: "data/audio/{sift}/{year}/{month:02}/{filename}"
max_duration: 600 # seconds
plot: true
show_plot: false
plot_path_template: "data/plots/{sift}/{year}/{month:02}/{day:02}/{plot_name}.png"
window_size: 512

Expand All @@ -54,6 +55,7 @@ pipeline:
model_sample_rate: 10000
plot_path_template: "data/plots/results/{year}/{month:02}/{plot_name}.png"
url: https://tfhub.dev/google/humpback_whale/1
model_url: "http://127.0.0.1:5000/predict"

postprocess:
min_gap: 60 # 1 minute
Expand Down
Loading

0 comments on commit eba95dd

Please sign in to comment.