Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tutorial stuff #3

Merged
merged 16 commits into from
Jan 31, 2024
21 changes: 21 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Checks

on:
push:
branches: ["master"]
pull_request:
jobs:
build_and_test:
name: Rust project
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- name: Cargo build
run: cargo build
- name: Cargo test
run: cargo test
- name: Rustfmt
run: cargo fmt --check
6 changes: 5 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ edition = "2021"

[features]
default = ["ndarray-linalg/intel-mkl"]
# Do not use me unless you want a lot of npy files dumped in your CWD
debug_dump = ["ndarray-npy"]

[dependencies]
anyhow = "1.0.71"
lbfgsb = "0.1.0"
ndarray = { version = "0.15.6", features = ["approx", "rayon"] }
ndarray-linalg = "0.16.0"
ndarray-npy = { version = "0.8.1", optional = true }
ndarray-rand = "0.14.0"
ndarray-stats = "0.5.1"
rand = "0.8.5"
rand_isaac = "0.3.0"
realfft = "3.3.0"
thiserror = "1.0.40"
tracing = "0.1.37"

[dev-dependencies]
approx = "0.5.1"
clap = { version = "4.4.7", features = ["derive"] }
float-cmp = "0.9.0"
hound = "3.5.1"
ndarray-npy = "0.8.1"
rand_isaac = "0.3.0"
90 changes: 63 additions & 27 deletions examples/run_griffin_lim.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,75 @@
use clap::Parser;
use griffin_lim::GriffinLim;
use hound::{SampleFormat, WavSpec, WavWriter};
use ndarray::prelude::*;
use ndarray_npy::read_npy;
use std::error::Error;
use std::path::PathBuf;
use std::result::Result;
use std::time::Instant;

#[derive(Parser, Debug)]
pub struct Args {
#[clap(long, short, default_value = "example_spectrogram.npy")]
input: String,
#[clap(short, long, default_value = "output.wav")]
output: PathBuf,
#[clap(long, default_value = "22050")]
sample_rate: u32,
#[clap(long, default_value = "1024")]
ffts: usize,
#[clap(long, default_value = "80")]
mels: usize,
#[clap(long, default_value = "256")]
hop_length: usize,
#[clap(long, default_value = "8000.0")]
max_frequency: f32,
#[clap(long, default_value = "1.7")]
power: f32,
#[clap(long, default_value = "10")]
iters: usize,
}

fn main() -> Result<(), Box<dyn Error>> {
let spectrogram: Array2<f32> = read_npy("resources/example_spectrogram.npy")?;

let mel_basis = griffin_lim::mel::create_mel_filter_bank(22050.0, 1024, 80, 0.0, Some(8000.0));

for iter in [0, 1, 2, 5, 10] {
let timer = Instant::now();
let mut vocoder = GriffinLim::new(mel_basis.clone(), 1024 - 256, 1.5, 1, 0.99)?;
vocoder.iter = iter;
let audio = vocoder.infer(&spectrogram)?;
let duration = Instant::now().duration_since(timer);
let rtf = duration.as_secs_f32() / (audio.len() as f32 / 22050_f32);
println!("Iterations: {}, rtf: {}", iter, rtf);
let spec = WavSpec {
channels: 1,
sample_rate: 22050,
bits_per_sample: 32,
sample_format: SampleFormat::Float,
};

let mut writer = WavWriter::create(format!("audio_output_griffinlim_{}.wav", iter), spec)?;

for sample in audio {
writer.write_sample(sample)?;
}

writer.finalize()?;
println!("Saved audio_output_griffinlim_{}.wav", iter);
let args = Args::parse();
let spectrogram: Array2<f32> = read_npy(args.input)?;

let mel_basis = griffin_lim::mel::create_mel_filter_bank(
args.sample_rate as f32,
args.ffts,
args.mels,
0.0,
Some(args.max_frequency),
);

let timer = Instant::now();
let vocoder = GriffinLim::new(
mel_basis.clone(),
args.ffts - args.hop_length,
args.power,
args.iters,
0.99,
)?;
let audio = vocoder.infer(&spectrogram)?;
let duration = Instant::now().duration_since(timer);
let rtf = duration.as_secs_f32() / (audio.len() as f32 / args.sample_rate as f32);
println!("Iterations: {}, rtf: {}", args.iters, rtf);

let spec = WavSpec {
channels: 1,
sample_rate: args.sample_rate,
bits_per_sample: 16,
sample_format: SampleFormat::Int,
};

let mut wav_writer = WavWriter::create(&args.output, spec)?;

let mut i16_writer = wav_writer.get_i16_writer(audio.len() as u32);
for sample in &audio {
i16_writer.write_sample((*sample * i16::MAX as f32) as i16);
}
i16_writer.flush()?;

println!("Saved {}", args.output.display());
Ok(())
}
56 changes: 39 additions & 17 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,39 @@ use ndarray::{par_azip, prelude::*, ScalarOperand};
use ndarray_linalg::error::LinalgError;
use ndarray_linalg::svd::SVD;
use ndarray_linalg::{Lapack, Scalar};
#[cfg(feature = "debug_dump")]
use ndarray_npy::WritableElement;
use ndarray_rand::rand_distr::uniform::SampleUniform;
use ndarray_rand::{rand_distr::Uniform, RandomExt};
use ndarray_stats::errors::MinMaxError;
use ndarray_stats::QuantileExt;
use num_traits::{Float, FloatConst, FromPrimitive};
use rand::SeedableRng;
use rand_isaac::isaac64::Isaac64Rng;
use realfft::num_complex::Complex;
use realfft::num_traits;
use realfft::num_traits::AsPrimitive;
use realfft::RealFftPlanner;
use std::fmt::Display;
use tracing::warn;

macro_rules! debug_dump_array {
($file:expr, $array:expr) => {
#[cfg(feature = "debug_dump")]
if let Err(e) = ndarray_npy::write_npy($file, &$array.view()) {
tracing::error!("Failed to write '{:?}': {}", $file, e);
}
};
}

/// Do not use this in any real way. Because we can't cfg on trait bounds and want to ensure the
/// matrices are dump-able via trait bounds we need to remove the cfg from the trait bound and push
/// it up to here. I tried doing this via trait inheritance but this didn't work for the `Complex<T>`
/// bound and this seemed the only reliable way.
#[cfg(not(feature = "debug_dump"))]
pub trait WritableElement {}

#[cfg(not(feature = "debug_dump"))]
impl<T> WritableElement for T {}

pub mod mel;

pub struct GriffinLim {
Expand All @@ -36,6 +55,7 @@ impl GriffinLim {
iter: usize,
momentum: f32,
) -> anyhow::Result<Self> {
debug_dump_array!("mel_basis.npy", mel_basis);
let nfft = 2 * (mel_basis.dim().1 - 1);
if noverlap >= nfft {
bail!(
Expand Down Expand Up @@ -63,6 +83,7 @@ impl GriffinLim {
}

pub fn infer(&self, mel_spec: &Array2<f32>) -> anyhow::Result<Array1<f32>> {
debug_dump_array!("mel_spectrogram.npy", mel_spec);
// mel_basis has dims (nmel, nfft)
// lin_spec has dims (nfft, time)
// mel_spec has dims (nmel, time)
Expand All @@ -71,6 +92,7 @@ impl GriffinLim {

// correct for "power" parameter of mel-spectrogram
lin_spec.mapv_inplace(|x| x.powf(1.0 / self.power));
debug_dump_array!("linear_spectrogram.npy", lin_spec);

let params = Parameters {
momentum: self.momentum,
Expand Down Expand Up @@ -106,7 +128,6 @@ impl GriffinLim {
/// Parameters to provide to the griffin-lim vocoder
pub struct Parameters<T> {
momentum: T,
seed: u64,
iter: usize,
init_random: bool,
}
Expand All @@ -128,7 +149,6 @@ where
pub fn new() -> Self {
Self {
momentum: T::from_f32(0.99).unwrap(),
seed: 42,
iter: 32,
init_random: true,
}
Expand All @@ -145,12 +165,6 @@ where
self
}

/// A random seed to use for initializing the phase.
pub fn seed(mut self, seed: u64) -> Self {
self.seed = seed;
self
}

/// Number of iterations to run - default value is 32.
pub fn iter(mut self, iter: usize) -> Self {
self.iter = iter;
Expand Down Expand Up @@ -273,8 +287,8 @@ pub fn griffin_lim<T>(
noverlap: usize,
) -> anyhow::Result<Array1<T>>
where
T: realfft::FftNum + Float + FloatConst + Display + SampleUniform,
Complex<T>: ScalarOperand,
T: realfft::FftNum + Float + FloatConst + Display + SampleUniform + WritableElement,
Complex<T>: ScalarOperand + WritableElement,
{
griffin_lim_with_params(spectrogram, nfft, noverlap, Parameters::new())
}
Expand All @@ -288,11 +302,10 @@ pub fn griffin_lim_with_params<T>(
params: Parameters<T>,
) -> anyhow::Result<Array1<T>>
where
T: realfft::FftNum + Float + FloatConst + Display + SampleUniform,
Complex<T>: ScalarOperand,
T: realfft::FftNum + Float + FloatConst + Display + SampleUniform + WritableElement,
Complex<T>: ScalarOperand + WritableElement,
{
// set up griffin lim parameters
let mut rng = Isaac64Rng::seed_from_u64(params.seed);
if params.momentum > T::one() || params.momentum < T::zero() {
bail!("Momentum is {}, should be in range [0,1]", params.momentum);
}
Expand All @@ -302,6 +315,7 @@ where

// Initialise estimate
let mut estimate = if params.init_random {
let mut rng = rand::thread_rng();
let mut angles = Array2::<T>::random_using(
spectrogram.raw_dim(),
Uniform::from(-T::PI()..T::PI()),
Expand All @@ -314,8 +328,11 @@ where
} else {
spectrogram.clone()
};
let mut _est_i = 1;
debug_dump_array!("estimate_spec_0.npy", estimate);

// TODO: Pre-allocate inverse and rebuilt and use `.assign` instead of `=`
// this requires some fighting with the borow checker
// this requires some fighting with the borrow checker
let mut inverse: Array1<T>;
let mut rebuilt: Array2<Complex<T>>;
let mut tprev: Option<Array2<Complex<T>>> = None;
Expand All @@ -335,12 +352,15 @@ where
} else {
tprev = Some(rebuilt);
}
// Get angles from estimate and apply to magnitueds
// Get angles from estimate and apply to magnitudes
let eps = T::min_positive_value();
// get angles from new estimate
estimate.mapv_inplace(|x| x / (x.norm() + eps));
// enforce magnitudes
estimate.assign(&(&estimate * &spectrogram));

debug_dump_array!(format!("estimate_spec_{}.npy", _est_i), estimate);
_est_i += 1;
}
let mut signal = istft(&estimate, &window, planner, nfft, noverlap);
let norm = T::from(nfft).unwrap();
Expand Down Expand Up @@ -490,6 +510,8 @@ where
mod tests {
use float_cmp::assert_approx_eq;
use ndarray_npy::read_npy;
use rand::SeedableRng;
use rand_isaac::isaac64::Isaac64Rng;

use super::*;

Expand Down
1 change: 1 addition & 0 deletions tutorial/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
estimate_spec*
Loading
Loading