Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work on a take api #32

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ serde = ["dep:serde"]
[dependencies]
anyhow = "1.0.86"
ndarray = "0.16"
ort = "=2.0.0-rc.6"
ort = "=2.0.0-rc.9"
rubato = { version = "0.16.0", optional = true}
serde = { version = "1.0.208", features = ["derive"], optional = true }
thiserror = "1.0.64"
Expand Down
201 changes: 173 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pub use crate::audio_resampler::resample_pcm;
pub use crate::errors::VadError;
use anyhow::{bail, Context, Result};
use ndarray::{Array1, Array2, Array3, ArrayBase, Ix1, Ix3, OwnedRepr};
use ort::{GraphOptimizationLevel, Session};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::ops::Range;
use std::path::Path;
use std::time::Duration;
Expand Down Expand Up @@ -44,8 +44,6 @@ pub struct VadSession {

/// Current start of the speech in milliseconds
speech_start_ms: Option<usize>,
/// Current end of the speech in milliseconds
speech_end_ms: Option<usize>,

/// Cached current active samples
cached_active_speech: Vec<f32>,
Expand Down Expand Up @@ -140,7 +138,6 @@ impl VadSession {
deleted_samples: 0,
silent_samples: 0,
speech_start_ms: None,
speech_end_ms: None,
cached_active_speech: vec![],
})
}
Expand Down Expand Up @@ -218,7 +215,7 @@ impl VadSession {
Ok(transitions)
}

pub fn forward(&mut self, input: Vec<f32>) -> Result<ort::Value> {
pub fn forward(&mut self, input: Vec<f32>) -> Result<ort::value::Value> {
let samples = input.len();
let audio_tensor = Array2::from_shape_vec((1, samples), input)?;
let mut result = self.model.run(ort::inputs![
Expand Down Expand Up @@ -307,7 +304,6 @@ impl VadSession {
timestamp_ms: start_ms,
});
self.speech_start_ms = Some(start_ms);
self.speech_end_ms = None;
}

if prob < self.config.negative_speech_threshold {
Expand All @@ -319,24 +315,22 @@ impl VadSession {
- self.silent_samples)
/ (self.config.sample_rate / 1000);

self.speech_end_ms = Some(speech_end_ms);
vad_change = Some(VadTransition::SpeechEnd {
start_timestamp_ms: start_ms,
end_timestamp_ms: speech_end_ms,
samples: self.get_current_speech().to_vec(),
});

// Need to delete the current speech samples from internal buffer to prevent OOM.
assert!(self.speech_start_ms.is_some() && self.speech_end_ms.is_some());
assert!(self.speech_start_ms.is_some());
self.cached_active_speech = self.get_current_speech().to_vec();
let speech_end_idx =
self.speech_end_ms.unwrap() * self.config.sample_rate / 1000
- self.deleted_samples;
let speech_end_idx = self.unchecked_duration_to_index(
Duration::from_millis(speech_end_ms as u64),
);
let to_delete_idx = 0..(speech_end_idx + 1);
self.session_audio.drain(to_delete_idx);
self.deleted_samples += speech_end_idx + 1;
self.speech_start_ms = None;
self.speech_end_ms = None;
}
self.state = VadState::Silence
}
Expand All @@ -349,13 +343,79 @@ impl VadSession {
Ok(vad_change)
}

/// This will remove audio in the buffer until a duration and panic if it exceeds the duration.
/// This won't touch the active speech cache or the VAD state. It's intended usage is if the
/// current speech buffer is too long and you want to remove some for processing and not have
/// it be re-processed or considered again.
///
/// If there is no remaining audio within the range this will return an empty vector.
/// Additionally, if the speech is just in the cached last-segment it won't take from that
/// (though this could be done in the future).
///
/// # Panics
///
/// If the time given is beyond the range of the current session this will panic.
pub fn take_until(&mut self, end: Duration) -> Vec<f32> {
if end > self.session_time() {
panic!(
"{}ms is greater than session time of {}ms",
end.as_millis(),
self.session_time().as_millis()
);
} else {
match self.duration_to_index(end) {
Some(s) => {
let mut returned_audio = self.session_audio.split_off(s);
std::mem::swap(&mut self.session_audio, &mut returned_audio);
self.deleted_samples += returned_audio.len();
if matches!(self.state, VadState::Speech { .. }) {
if let Some(start_ms) = self.speech_start_ms.take() {
if start_ms < end.as_millis() as usize {
self.speech_start_ms = Some(end.as_millis() as usize);
}
}
}
returned_audio
}
None => vec![],
}
}
}

/// Returns whether the vad current believes the audio to contain speech
pub fn is_speaking(&self) -> bool {
matches!(self.state, VadState::Speech {
redemption_passed, ..
} if redemption_passed)
}

/// Takes a duration and converts it to an index if it's within the current session audio or
/// `None` if it's not.
///
/// # Panics
///
/// If this is out of the range it will panic
fn unchecked_duration_to_index(&self, duration: Duration) -> usize {
match self.duration_to_index(duration) {
Some(idx) => idx,
None => panic!(
"Duration {}ms is outside of session audio range",
duration.as_millis()
),
}
}

/// Takes a duration and converts it to an index if it's within the current session audio or
/// `None` if it's not.
fn duration_to_index(&self, duration: Duration) -> Option<usize> {
let unadjusted_index = duration.as_millis() as usize * (self.config.sample_rate / 1000);
if unadjusted_index < self.deleted_samples {
None
} else {
Some(unadjusted_index - self.deleted_samples)
}
}

/// Gets the speech within a given range of milliseconds. You can use previous speech start/end
/// event pairs to get speech windows before the current speech using this API. If end is
/// `None` this will return from the start point to the end of the buffer.
Expand All @@ -365,12 +425,11 @@ impl VadSession {
/// If the range is out of bounds of the speech buffer this method will panic due to an
/// assertion failure.
pub fn get_speech(&self, start_ms: usize, end_ms: Option<usize>) -> &[f32] {
let speech_start_idx = start_ms * (self.config.sample_rate / 1000) - self.deleted_samples;
assert!(speech_start_idx < self.session_audio.len());
let speech_start_idx =
self.unchecked_duration_to_index(Duration::from_millis(start_ms as u64));
if let Some(speech_end) = end_ms {
let speech_end_idx =
speech_end * (self.config.sample_rate / 1000) - self.deleted_samples;
assert!(speech_end_idx < self.session_audio.len());
self.unchecked_duration_to_index(Duration::from_millis(speech_end as u64));
&self.session_audio[speech_start_idx..speech_end_idx]
} else {
&self.session_audio[speech_start_idx..]
Expand All @@ -383,7 +442,7 @@ impl VadSession {
/// been applied.
pub fn get_current_speech(&self) -> &[f32] {
if let Some(speech_start) = self.speech_start_ms {
self.get_speech(speech_start, self.speech_end_ms)
self.get_speech(speech_start, None)
} else {
&self.cached_active_speech
}
Expand Down Expand Up @@ -425,7 +484,6 @@ impl VadSession {
self.h_tensor = Array3::<f32>::zeros((2, 1, 64));
self.c_tensor = Array3::<f32>::zeros((2, 1, 64));
self.speech_start_ms = None;
self.speech_end_ms = None;
self.silent_samples = 0;
self.state = VadState::Silence;
}
Expand Down Expand Up @@ -503,13 +561,28 @@ impl VadConfig {
#[cfg(test)]
mod tests {
use super::*;
use tracing_test::traced_test;

/// Feed only silence into the network and ensure that `get_current_speech` returns an empty
/// slice
#[test]
#[traced_test]
fn only_silence_get_speech() {
let mut session = VadSession::new(VadConfig::default()).unwrap();
let short_audio = vec![0.0; 1000];

session.process(&short_audio).unwrap();

assert_eq!(session.get_current_speech(), &[]);
}

/// Basic smoke test that the model loads correctly and we haven't committed rubbish to the
/// repo.
#[test]
#[traced_test]
fn model_loads() {
let _sesion = VadSession::new(VadConfig::default()).unwrap();
let _sesion =
let _session = VadSession::new(VadConfig::default()).unwrap();
let _session =
VadSession::new_from_path("models/silero_vad.onnx", VadConfig::default()).unwrap();
}

Expand All @@ -518,6 +591,7 @@ mod tests {
/// short inference in the internal inference call bubbles up an error but when using the
/// public API no error is presented.
#[test]
#[traced_test]
fn short_audio_handling() {
let mut session = VadSession::new(VadConfig::default()).unwrap();

Expand All @@ -532,6 +606,7 @@ mod tests {
/// Check that a long enough packet of just zeros gets an inference and it doesn't flag as
/// transitioning to speech
#[test]
#[traced_test]
fn silence_handling() {
let mut session = VadSession::new(VadConfig::default()).unwrap();
let silence = vec![0.0; 30 * 16]; // 30ms of silence
Expand All @@ -542,6 +617,7 @@ mod tests {

/// We only allow for 8khz and 16khz audio.
#[test]
#[traced_test]
fn reject_invalid_sample_rate() {
let mut config = VadConfig::default();
config.sample_rate = 16000;
Expand All @@ -560,6 +636,7 @@ mod tests {
/// Just a sanity test of speech duration to make sure the calculation seems roughly right in
/// terms of number of samples, sample rate and taking into account the speech starts/ends.
#[test]
#[traced_test]
fn simple_speech_duration() {
let mut config = VadConfig::default();
config.sample_rate = 8000;
Expand All @@ -571,19 +648,14 @@ mod tests {
session.speech_start_ms = Some(10);
assert_eq!(session.current_speech_duration(), Duration::from_secs(2));

session.speech_end_ms = Some(1010);
assert_eq!(session.current_speech_duration(), Duration::from_secs(1));

session.config.sample_rate = 16000;
session.speech_end_ms = Some(510);
assert_eq!(
session.current_speech_duration(),
Duration::from_millis(500)
);
session.session_audio.resize(16160, 0.0);
assert_eq!(session.current_speech_duration(), Duration::from_secs(1));
}

/// The provided audio sample must be in the range -1.0 to 1.0
#[test]
#[traced_test]
fn audio_sample_range() {
let config = VadConfig::default();

Expand All @@ -601,4 +673,77 @@ mod tests {
VadError::InvalidData
));
}

/// Apply some audio with speech in and ensure that the take API works as expected
#[test]
#[traced_test]
fn taking_audio() {
let samples: Vec<f32> = hound::WavReader::open("tests/audio/sample_4.wav")
.unwrap()
.into_samples()
.map(|x| {
let modified = x.unwrap_or(0i16) as f32 / (i16::MAX as f32);
modified.clamp(-1.0, 1.0)
})
.collect();

let config = VadConfig::default();
let mut session = VadSession::new(config.clone()).unwrap();

let chunk_size = 480; // 30ms
let max_chunks = samples.len() / chunk_size;

let mut start_time = 0;

for i in 0..max_chunks {
let start = i * chunk_size;
let end = (start + chunk_size).min(samples.len());
let trans = session.process(&samples[start..end]).unwrap();

for transition in &trans {
match transition {
VadTransition::SpeechStart { timestamp_ms } => {
start_time = *timestamp_ms as u64;
}
_ => panic!("Oh no it's over"),
}
}

if session.is_speaking()
&& (session.session_time() - Duration::from_millis(start_time))
>= Duration::from_millis(120)
{
break;
}
}
assert!(session.is_speaking(), "never found speech");

let current_untaken = session.get_current_speech().to_vec();

let taken = session.take_until(Duration::from_millis(start_time + 60));

assert!(session.current_speech_samples() < current_untaken.len());
assert_eq!(
session.current_speech_samples() + taken.len(),
current_untaken.len()
);

assert!(session
.take_until(Duration::from_millis(start_time))
.is_empty());
}

/// If we take past our boundary we panic!
#[test]
#[traced_test]
#[should_panic]
fn excessive_take() {
let config = VadConfig::default();
let mut session = VadSession::new(config.clone()).unwrap();

let silence = vec![0.0; 16000];
let _ = session.process(&silence);

session.take_until(Duration::from_millis(1001));
}
}
Loading