diff --git a/examples/classify_transform.py b/examples/classify_transform.py
new file mode 100644
index 0000000..e2b54c9
--- /dev/null
+++ b/examples/classify_transform.py
@@ -0,0 +1,194 @@
+"""
+Beam PTransform for classifying whale audio.
+Gets stuck on classification, either due to memory issues or model serialization.
+Kept for reference, but replaced by InferenceClient in classify.py.
+"""
+from apache_beam.io import filesystems
+from datetime import datetime
+
+import apache_beam as beam
+import io
+import librosa
+import logging
+import math
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import scipy
+import time
+import tensorflow_hub as hub
+import tensorflow as tf
+
+from config import load_pipeline_config
+
+
+config = load_pipeline_config()
+
+
+class BaseClassifier(beam.PTransform):
+    name = "BaseClassifier"
+
+
+    def __init__(self):
+        self.source_sample_rate = config.audio.source_sample_rate
+        self.model_sample_rate = config.classify.model_sample_rate
+        self.batch_duration = config.classify.batch_duration
+        self.model_path = config.classify.model_path
+
+    def _preprocess(self, pcoll):
+        signal, start, end, encounter_ids = pcoll
+        key = self._build_key(start, end, encounter_ids)
+
+        # Resample
+        signal = self._resample(signal)
+
+        batch_samples = self.batch_duration * self.model_sample_rate
+
+        if signal.shape[0] > batch_samples:
+            logging.debug(f"Signal size exceeds max sample size {batch_samples}.")
+
+            split_indices = [batch_samples*(i+1) for i in range(math.floor(signal.shape[0] / batch_samples))]
+            signal_batches = np.array_split(signal, split_indices)
+            logging.debug(f"Split signal into {len(signal_batches)} batches of size {batch_samples}.")
+            logging.debug(f"Size of final batch: {len(signal_batches[-1])}")
+
+            for batch in signal_batches:
+                yield (key, batch)
+        else:
+            yield (key, signal)
+
+    def _build_key(
+        self,
+        start_time: datetime,
+        end_time: datetime,
+        encounter_ids: list,
+    ):
+        start_str = start_time.strftime('%Y%m%dT%H%M%S')
+        end_str = end_time.strftime('%H%M%S')
+        encounter_str = "_".join(encounter_ids)
+        return f"{start_str}-{end_str}_{encounter_str}"
+
+    def _postprocess(self, pcoll):
+        return pcoll
+
+    def _get_model(self):
+        model = hub.load(self.model_path)
+        return model
+
+    def _resample(self, signal):
+        logging.info(
+            f"Resampling signal from {self.source_sample_rate} to {self.model_sample_rate}")
+        return librosa.resample(
+            signal,
+            orig_sr=self.source_sample_rate,
+            target_sr=self.model_sample_rate
+        )
+
+
+class GoogleHumpbackWhaleClassifier(BaseClassifier):
+    """
+    Model docs: https://tfhub.dev/google/humpback_whale/1
+    """
+    name = "GoogleHumpbackWhaleClassifier"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = self._get_model()
+        self.score_fn = self.model.signatures['score']
+        self.metadata_fn = self.model.signatures['metadata']
+
+    def expand(self, pcoll):
+        return (
+            pcoll
+            | "Preprocess" >> beam.Map(self._preprocess)
+            | "Classify" >> beam.Map(self._classify)
+            | "Postprocess" >> beam.Map(self._postprocess)
+        )
+
+    def _classify(self, pcoll):
+        key, signal = pcoll
+
+        start_classify = time.time()
+
+        # We specify a 1-sec score resolution:
+        context_step_samples = tf.cast(self.model_sample_rate, tf.int64)
+
+        logging.info(f'\n==> Applying model ...')
+        logging.debug(f'  initial input: len(signal_10kHz) = {len(signal)}')
+
+        waveform1 = np.expand_dims(signal, axis=1)
+        waveform_exp = tf.expand_dims(waveform1, 0)  # makes a batch of size 1
+        logging.debug(f"  final input: waveform_exp.shape = {waveform_exp.shape}")
+
+        signal_scores = self.score_fn(
+            waveform=waveform_exp,
+            context_step_samples=context_step_samples
+        )
+        score_values = signal_scores['scores'].numpy()[0]
+        logging.info(f'==> Model applied. Obtained {len(score_values)} score_values')
+        logging.info(f'==> Elapsed time: {time.time() - start_classify} seconds')
+
+        return (key, score_values)
+
+    def _plot_spectrogram_scipy(self, signal, epsilon=1e-15):
+        # Compute spectrogram:
+        w = scipy.signal.get_window('hann', self.model_sample_rate)
+        f, t, psd = scipy.signal.spectrogram(
+            signal,  # TODO make sure this is resampled signal
+            self.model_sample_rate,
+            nperseg=self.model_sample_rate,
+            noverlap=0,
+            window=w,
+            nfft=self.model_sample_rate,
+        )
+        psd = 10*np.log10(psd+epsilon) - self.hydrophone_sensitivity
+
+        # Plot spectrogram:
+        fig = plt.figure(figsize=(20, round(20/3)))  # 3:1 aspect ratio
+        plt.imshow(
+            psd,
+            aspect='auto',
+            origin='lower',
+            vmin=30,
+            vmax=90,
+            cmap='Blues',
+        )
+        plt.yscale('log')
+        y_max = self.model_sample_rate / 2
+        plt.ylim(10, y_max)
+
+        plt.colorbar()
+
+        plt.xlabel('Seconds')
+        plt.ylabel('Frequency (Hz)')
+        plt.title(f'Calibrated spectrum levels, {self.model_sample_rate / 1000.0} kHz data')
+
+    def _plot_scores(self, pcoll, scores, med_filt_size=None):
+        audio, start, end, encounter_ids = pcoll
+        key = self._build_key(start, end, encounter_ids)
+
+        # repeat last value to also see a step at the end:
+        scores = np.concatenate((scores, scores[-1:]))
+        x = range(len(scores))
+        plt.step(x, scores, where='post')
+        plt.plot(x, scores, 'o', color='lightgrey', markersize=9)
+
+        plt.grid(axis='x', color='0.95')
+        plt.xlim(xmin=0, xmax=len(scores) - 1)
+        plt.ylabel('Model Score')
+        plt.xlabel('Seconds')
+
+        if med_filt_size is not None:
+            scores_int = [int(s[0]*1000) for s in scores]
+            meds_int = scipy.signal.medfilt(scores_int, kernel_size=med_filt_size)
+            meds = [m/1000. for m in meds_int]
+            plt.plot(x, meds, 'p', color='black', markersize=9)
+
+        plot_path = config.classify.plot_path_template.format(
+            year=start.year,
+            month=start.month,
+            day=start.day,
+            plot_name=key
+        )
+        plt.savefig(plot_path)
+        plt.show()
diff --git a/src/pipeline/app.py b/src/pipeline/app.py
index c3a15b1..407b4e0 100644
--- a/src/pipeline/app.py
+++ b/src/pipeline/app.py
@@ -4,6 +4,7 @@
 from stages.search import GeometrySearch
 from stages.audio import RetrieveAudio, WriteAudio, WriteSiftedAudio
 from stages.sift import Butterworth
+from stages.classify import WhaleClassifier
 from config import load_pipeline_config
 
 config = load_pipeline_config()
@@ -18,27 +19,28 @@ def run():
     }
 
     with beam.Pipeline(options=pipeline_options) as p:
-        input_data = p | "Create Input" >> beam.Create([args])
-        search_results = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())
+        input_data = p | "Create Input" >> beam.Create([args])
+        search_output = input_data | "Run Geometry Search" >> beam.ParDo(GeometrySearch())
 
-        audio_results = search_results | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
-        audio_files = audio_results | "Store Audio (temp)" >> beam.ParDo(WriteAudio())
+        audio_output = search_output | "Retrieve Audio" >> beam.ParDo(RetrieveAudio())
+        audio_files = audio_output | "Store Audio (temp)" >> beam.ParDo(WriteAudio())
 
-        sifted_audio = audio_results | "Sift Audio" >> Butterworth()
+        sifted_audio = audio_output | "Sift Audio" >> Butterworth()
         sifted_audio_files = sifted_audio | "Store Sifted Audio" >> beam.ParDo(WriteSiftedAudio("butterworth"))
 
-        # For debugging, you can write the output to a text file
-        # audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt')
-        # search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt')
+        classifications = sifted_audio | "Classify Audio" >> WhaleClassifier(config)
 
-        # classified_audio = filtered_audio | "Classify Audio" >> ClassifyAudio()
-
         # # Post-process the labels
         # postprocessed_labels = classified_audio | "Postprocess Labels" >> PostprocessLabels()
 
         # Output results
         # postprocessed_labels | "Write Results" >> beam.io.WriteToText("output.txt")
 
+        # For debugging, you can write the output to a text file
+        # audio_files | "Write Audio Output" >> beam.io.WriteToText('audio_files.txt')
+        # search_results | "Write Search Output" >> beam.io.WriteToText('search_results.txt')
+
+
 if __name__ == "__main__":
     run()
diff --git a/src/pipeline/config.yaml b/src/pipeline/config.yaml
index 93d7984..ab1763e 100644
--- a/src/pipeline/config.yaml
+++ b/src/pipeline/config.yaml
@@ -36,6 +36,7 @@ pipeline:
     output_path_template: "data/audio/{sift}/{year}/{month:02}/{filename}"
     max_duration: 600  # seconds
     plot: true
+    show_plot: false
    plot_path_template: "data/plots/{sift}/{year}/{month:02}/{day:02}/{plot_name}.png"
 
     window_size: 512
@@ -54,6 +55,7 @@ pipeline:
     model_sample_rate: 10000
     plot_path_template: "data/plots/results/{year}/{month:02}/{plot_name}.png"
     url: https://tfhub.dev/google/humpback_whale/1
+    model_url: "http://127.0.0.1:5000/predict"
 
   postprocess:
     min_gap: 60  # 1 minute
diff --git a/src/pipeline/stages/classify.py b/src/pipeline/stages/classify.py
index 38ed0cb..3a73071 100644
--- a/src/pipeline/stages/classify.py
+++ b/src/pipeline/stages/classify.py
@@ -1,34 +1,30 @@
-from apache_beam.io import filesystems
+import apache_beam as beam
+
 from datetime import datetime
+from typing import Dict, Any
+from types import SimpleNamespace
 
-import apache_beam as beam
-import io
 import librosa
 import logging
-import math
-import matplotlib.pyplot as plt
 import numpy as np
-import os
-import scipy
-import time
-import tensorflow_hub as hub
-import tensorflow as tf
 
-from config import load_pipeline_config
+import requests
+import math
 
-config = load_pipeline_config()
+logging.getLogger().setLevel(logging.INFO)
 
 
 class BaseClassifier(beam.PTransform):
     name = "BaseClassifier"
 
-    def __init__(self):
+    def __init__(self, config: Dict[str, Any]):
         self.source_sample_rate = config.audio.source_sample_rate
-        self.model_sample_rate = config.classify.model_sample_rate
-
-        self.model_path = config.classify.model_path
+
+        self.batch_duration = config.classify.batch_duration
+        self.model_sample_rate = config.classify.model_sample_rate
+        self.model_url = config.classify.model_url
 
     def _preprocess(self, pcoll):
@@ -36,16 +32,22 @@ def _preprocess(self, pcoll):
 
         # Resample
         signal = self._resample(signal)
+        logging.info(f"Resampled signal shape: {signal.shape}")
+
+        # Expand final dimension
+        signal = np.expand_dims(signal, axis=1)
+        logging.info(f"Expanded signal shape: {signal.shape}")
 
-        batch_samples = self.batch_duration * self.sample_rate
+        # Split signal into batches (if necessary)
+        batch_samples = self.batch_duration * self.model_sample_rate
 
         if signal.shape[0] > batch_samples:
-            logging.debug(f"Signal size exceeds max sample size {batch_samples}.")
+            logging.info(f"Signal size exceeds max sample size {batch_samples}.")
 
             split_indices = [batch_samples*(i+1) for i in range(math.floor(signal.shape[0] / batch_samples))]
             signal_batches = np.array_split(signal, split_indices)
-            logging.debug(f"Split signal into {len(signal_batches)} batches of size {batch_samples}.")
-            logging.debug(f"Size fo final batch {len(signal_batches[1])}")
+            logging.info(f"Split signal into {len(signal_batches)} batches of size {batch_samples}.")
+            logging.info(f"Size of final batch: {len(signal_batches[-1])}")
 
             for batch in signal_batches:
                 yield (key, batch)
@@ -58,6 +60,7 @@ def _build_key(
         end_time: datetime,
         encounter_ids: list,
     ):
+        # TODO: Refactor this to a common place accessible in all modules that use key
         start_str = start_time.strftime('%Y%m%dT%H%M%S')
         end_str = end_time.strftime('%H%M%S')
         encounter_str = "_".join(encounter_ids)
@@ -66,13 +69,8 @@ def _build_key(
     def _postprocess(self, pcoll):
         return pcoll
 
-    def _get_model(self):
-        model = hub.load(self.model_path)
-        return model
-
     def _resample(self, signal):
-        logging.info(
-            f"Resampling signal from {self.source_sample_rate} to {self.model_sample_rate}")
+        logging.info(f"Resampling signal from {self.source_sample_rate} to {self.model_sample_rate}")
         return librosa.resample(
             signal,
             orig_sr=self.source_sample_rate,
@@ -80,113 +78,106 @@ def _resample(self, signal):
         )
 
 
-class GoogleHumpbackWhaleClassifier(BaseClassifier):
-    """
-    Model docs: https://tfhub.dev/google/humpback_whale/1
-    """
-    name = "GoogleHumpbackWhaleClassifier"
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.model = self._get_model()
-        self.score_fn = self.model.signatures['score']
-        self.metadata_fn = self.model.signatures['metadata']
-
-
+class WhaleClassifier(BaseClassifier):
     def expand(self, pcoll):
-        return (
-            pcoll
-            | "Preprocess" >> beam.Map(self._preprocess)
-            | "Classify" >> beam.Map(self._classify)
-            | "Postprocess" >> beam.Map(self._postprocess)
+        key_batch = pcoll | "Preprocess" >> beam.ParDo(self._preprocess)
+        batched_outputs = key_batch | "Classify" >> beam.ParDo(InferenceClient(self.model_url))
+        grouped_outputs = batched_outputs | "Combine batched_outputs" >> beam.CombinePerKey(ListCombine())
batched_outputs | "Combine batched_outputs" >> beam.CombinePerKey(ListCombine()) + outputs = pcoll | "Postprocess" >> beam.Map( + self._postprocess, + grouped_outputs=beam.pvalue.AsDict(grouped_outputs), ) + return outputs + def _postprocess(self, pcoll, grouped_outputs): + signal, start, end, encounter_ids = pcoll + key = self._build_key(start, end, encounter_ids) + output = grouped_outputs.get(key, []) + + logging.info(f"Postprocessing {key} with signal {len(signal)} and output {len(output)}") - def _classify(self, pcoll, ): - key, signal = pcoll + return signal, start, end, encounter_ids, output - start_classify = time.time() - # We specify a 1-sec score resolution: - context_step_samples = tf.cast(self.model_sample_rate, tf.int64) +class InferenceClient(beam.DoFn): + def __init__(self, model_url: str): + self.model_url = model_url - logging.info(f'\n==> Applying model ...') - logging.debug(f' inital input: len(signal_10kHz) = {len(signal)}') + def process(self, element): + key, batch = element - waveform1 = np.expand_dims(signal, axis=1) - waveform_exp = tf.expand_dims(waveform1, 0) # makes a batch of size 1 - logging.debug(f" final input: waveform_exp.shape = {waveform_exp.shape}") + # skip empty batches + if len(batch) == 0: + return {"key": key, "predictions": []} - signal_scores = self.score_fn( - waveform=waveform_exp, - context_step_samples=context_step_samples - ) - score_values = signal_scores['scores'].numpy()[0] - logging.info(f'==> Model applied. Obtained {len(score_values)} score_values') - logging.info(f'==> Elapsed time: {time.time() - start_classify} seconds') + if isinstance(batch, np.ndarray): + batch = batch.tolist() - return (key, score_values) - - def _plot_spectrogram_scipy(self, signal, epsilon = 1e-15): - # Compute spectrogram: - w = scipy.signal.get_window('hann', self.sample_rate) - f, t, psd = scipy.signal.spectrogram( - signal, # TODO make sure this is resampled signal - self.model_sample_rate, - nperseg=self.model_sample_rate, - noverlap=0, - window=w, - nfft=self.model_sample_rate, - ) - psd = 10*np.log10(psd+epsilon) - self.hydrophone_sensitivity - - # Plot spectrogram: - fig = plt.figure(figsize=(20, round(20/3))) # 3:1 aspect ratio - plt.imshow( - psd, - aspect='auto', - origin='lower', - vmin=30, - vmax=90, - cmap='Blues', - ) - plt.yscale('log') - y_max = self.model_sample_rate / 2 - plt.ylim(10, y_max) + data = { + "key": key, + "batch": batch, + } + + response = requests.post(self.model_url, json=data) + response.raise_for_status() - plt.colorbar() + yield response.json() - plt.xlabel('Seconds') - plt.ylabel('Frequency (Hz)') - plt.title(f'Calibrated spectrum levels, 16 {self.sample_rate / 1000.0} kHz data') +class ListCombine(beam.CombineFn): + name = "ListCombine" + # TODO refactor this to a place accessible in both sift.py and here - def _plot_scores(self, pcoll, scores, med_filt_size=None): - audio, start, end, encounter_ids = pcoll - key = self._build_key(start, end, encounter_ids) + def create_accumulator(self): + return [] + + def add_input(self, accumulator, input): + """ + Key is not available in this method, + though inputs are only added to accumulator under correct key. 
+ """ + logging.debug(f"Adding input {input} to {self.name} accumulator.") + if isinstance(input, np.ndarray): + input = input.tolist() + accumulator += input + return accumulator - # repeat last value to also see a step at the end: - scores = np.concatenate((scores, scores[-1:])) - x = range(len(scores)) - plt.step(x, scores, where='post') - plt.plot(x, scores, 'o', color='lightgrey', markersize=9) - - plt.grid(axis='x', color='0.95') - plt.xlim(xmin=0, xmax=len(scores) - 1) - plt.ylabel('Model Score') - plt.xlabel('Seconds') - - if med_filt_size is not None: - scores_int = [int(s[0]*1000) for s in scores] - meds_int = scipy.signal.medfilt(scores_int, kernel_size=med_filt_size) - meds = [m/1000. for m in meds_int] - plt.plot(x, meds, 'p', color='black', markersize=9) - - plot_path = config.classify.plot_path_template.format( - year=start.year, - month=start.month, - day=start.day, - plot_name=key + def merge_accumulators(self, accumulators): + return [item for sublist in accumulators for item in sublist] + + def extract_output(self, accumulator): + return accumulator + + + +def sample_run(): + signal = np.load("data/audio/butterworth/2016/12/20161221T004930-005030-9182.npy") + data = ( + signal, + datetime.strptime("2016-12-21T00:49:30", "%Y-%m-%dT%H:%M:%S"), + datetime.strptime("2016-12-21T00:50:30", "%Y-%m-%dT%H:%M:%S"), + ["9182"] + ) + + # simulate config (avoids local import) + config = SimpleNamespace( + audio = SimpleNamespace(source_sample_rate=16_000), + classify = SimpleNamespace( + batch_duration=30, # seconds + model_sample_rate=10_000, + model_url="http://127.0.0.1:5000/predict" + ), + ) + + with beam.Pipeline() as p: + output = ( + p + | beam.Create([(data)]) + | WhaleClassifier(config) ) - plt.savefig(plot_path) - plt.show() + logging.info(output) + + +if __name__ == "__main__": + sample_run() + \ No newline at end of file diff --git a/src/pipeline/stages/sift.py b/src/pipeline/stages/sift.py index 71bef62..ffcb808 100644 --- a/src/pipeline/stages/sift.py +++ b/src/pipeline/stages/sift.py @@ -35,6 +35,7 @@ class BaseSift(beam.PTransform): # plot params plot = config.sift.plot plot_path_template = config.sift.plot_path_template + show_plot = config.sift.show_plot def _build_key( self, @@ -146,7 +147,7 @@ def _plot_signal_detections(self, pcoll, min_max_detections, all_detections, par title += f"Encounters: {encounter_ids}" plt.title(title) plt.savefig(plot_path) - plt.show() + plt.show() if self.show_plot else plt.close() class Butterworth(BaseSift):