diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml new file mode 100644 index 0000000..cce2dad --- /dev/null +++ b/.github/workflows/checks.yml @@ -0,0 +1,21 @@ +name: Checks + +on: + push: + branches: ["master"] + pull_request: +jobs: + build_and_test: + name: Rust project + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - name: Cargo build + run: cargo build + - name: Cargo test + run: cargo test + - name: Rustfmt + run: cargo fmt --check diff --git a/Cargo.toml b/Cargo.toml index 5d560e0..a6a10d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,22 +5,26 @@ edition = "2021" [features] default = ["ndarray-linalg/intel-mkl"] +# Do not use me unless you want a lot of npy files dumped in your CWD +debug_dump = ["ndarray-npy"] [dependencies] anyhow = "1.0.71" lbfgsb = "0.1.0" ndarray = { version = "0.15.6", features = ["approx", "rayon"] } ndarray-linalg = "0.16.0" +ndarray-npy = { version = "0.8.1", optional = true } ndarray-rand = "0.14.0" ndarray-stats = "0.5.1" rand = "0.8.5" -rand_isaac = "0.3.0" realfft = "3.3.0" thiserror = "1.0.40" tracing = "0.1.37" [dev-dependencies] approx = "0.5.1" +clap = { version = "4.4.7", features = ["derive"] } float-cmp = "0.9.0" hound = "3.5.1" ndarray-npy = "0.8.1" +rand_isaac = "0.3.0" diff --git a/README.md b/README.md index 2a74075..a4d3a5a 100644 --- a/README.md +++ b/README.md @@ -9,4 +9,5 @@ audio. 
To learn more about this algorithm: +* [Our own tutorial](./tutorial/README.md) * [Papers with Code](https://paperswithcode.com/method/griffin-lim-algorithm) diff --git a/examples/run_griffin_lim.rs b/examples/run_griffin_lim.rs index f15d9bd..12cfbf7 100644 --- a/examples/run_griffin_lim.rs +++ b/examples/run_griffin_lim.rs @@ -1,39 +1,75 @@ +use clap::Parser; use griffin_lim::GriffinLim; use hound::{SampleFormat, WavSpec, WavWriter}; use ndarray::prelude::*; use ndarray_npy::read_npy; use std::error::Error; +use std::path::PathBuf; use std::result::Result; use std::time::Instant; +#[derive(Parser, Debug)] +pub struct Args { + #[clap(long, short, default_value = "example_spectrogram.npy")] + input: String, + #[clap(short, long, default_value = "output.wav")] + output: PathBuf, + #[clap(long, default_value = "22050")] + sample_rate: u32, + #[clap(long, default_value = "1024")] + ffts: usize, + #[clap(long, default_value = "80")] + mels: usize, + #[clap(long, default_value = "256")] + hop_length: usize, + #[clap(long, default_value = "8000.0")] + max_frequency: f32, + #[clap(long, default_value = "1.7")] + power: f32, + #[clap(long, default_value = "10")] + iters: usize, +} + fn main() -> Result<(), Box> { - let spectrogram: Array2 = read_npy("resources/example_spectrogram.npy")?; - - let mel_basis = griffin_lim::mel::create_mel_filter_bank(22050.0, 1024, 80, 0.0, Some(8000.0)); - - for iter in [0, 1, 2, 5, 10] { - let timer = Instant::now(); - let mut vocoder = GriffinLim::new(mel_basis.clone(), 1024 - 256, 1.5, 1, 0.99)?; - vocoder.iter = iter; - let audio = vocoder.infer(&spectrogram)?; - let duration = Instant::now().duration_since(timer); - let rtf = duration.as_secs_f32() / (audio.len() as f32 / 22050_f32); - println!("Iterations: {}, rtf: {}", iter, rtf); - let spec = WavSpec { - channels: 1, - sample_rate: 22050, - bits_per_sample: 32, - sample_format: SampleFormat::Float, - }; - - let mut writer = WavWriter::create(format!("audio_output_griffinlim_{}.wav", 
iter), spec)?; - - for sample in audio { - writer.write_sample(sample)?; - } - - writer.finalize()?; - println!("Saved audio_output_griffinlim_{}.wav", iter); + let args = Args::parse(); + let spectrogram: Array2 = read_npy(args.input)?; + + let mel_basis = griffin_lim::mel::create_mel_filter_bank( + args.sample_rate as f32, + args.ffts, + args.mels, + 0.0, + Some(args.max_frequency), + ); + + let timer = Instant::now(); + let vocoder = GriffinLim::new( + mel_basis.clone(), + args.ffts - args.hop_length, + args.power, + args.iters, + 0.99, + )?; + let audio = vocoder.infer(&spectrogram)?; + let duration = Instant::now().duration_since(timer); + let rtf = duration.as_secs_f32() / (audio.len() as f32 / args.sample_rate as f32); + println!("Iterations: {}, rtf: {}", args.iters, rtf); + + let spec = WavSpec { + channels: 1, + sample_rate: args.sample_rate, + bits_per_sample: 16, + sample_format: SampleFormat::Int, + }; + + let mut wav_writer = WavWriter::create(&args.output, spec)?; + + let mut i16_writer = wav_writer.get_i16_writer(audio.len() as u32); + for sample in &audio { + i16_writer.write_sample((*sample * i16::MAX as f32) as i16); } + i16_writer.flush()?; + + println!("Saved {}", args.output.display()); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 47b3caa..5aac7e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,13 @@ use ndarray::{par_azip, prelude::*, ScalarOperand}; use ndarray_linalg::error::LinalgError; use ndarray_linalg::svd::SVD; use ndarray_linalg::{Lapack, Scalar}; +#[cfg(feature = "debug_dump")] +use ndarray_npy::WritableElement; use ndarray_rand::rand_distr::uniform::SampleUniform; use ndarray_rand::{rand_distr::Uniform, RandomExt}; use ndarray_stats::errors::MinMaxError; use ndarray_stats::QuantileExt; use num_traits::{Float, FloatConst, FromPrimitive}; -use rand::SeedableRng; -use rand_isaac::isaac64::Isaac64Rng; use realfft::num_complex::Complex; use realfft::num_traits; use realfft::num_traits::AsPrimitive; @@ -18,6 +18,25 @@ use 
realfft::RealFftPlanner; use std::fmt::Display; use tracing::warn; +macro_rules! debug_dump_array { + ($file:expr, $array:expr) => { + #[cfg(feature = "debug_dump")] + if let Err(e) = ndarray_npy::write_npy($file, &$array.view()) { + tracing::error!("Failed to write '{:?}': {}", $file, e); + } + }; +} + +/// Do not use this in any real way. Because we can't cfg on trait bounds and want to ensure the +/// matrices are dump-able via trait bounds we need to remove the cfg from the trait bound and push +/// it up to here. I tried doing this via trait inheritance but this didn't work for the `Complex` +/// bound and this seemed the only reliable way. +#[cfg(not(feature = "debug_dump"))] +pub trait WritableElement {} + +#[cfg(not(feature = "debug_dump"))] +impl WritableElement for T {} + pub mod mel; pub struct GriffinLim { @@ -36,6 +55,7 @@ impl GriffinLim { iter: usize, momentum: f32, ) -> anyhow::Result { + debug_dump_array!("mel_basis.npy", mel_basis); let nfft = 2 * (mel_basis.dim().1 - 1); if noverlap >= nfft { bail!( @@ -63,6 +83,7 @@ impl GriffinLim { } pub fn infer(&self, mel_spec: &Array2) -> anyhow::Result> { + debug_dump_array!("mel_spectrogram.npy", mel_spec); // mel_basis has dims (nmel, nfft) // lin_spec has dims (nfft, time) // mel_spec has dims (nmel, time) @@ -71,6 +92,7 @@ impl GriffinLim { // correct for "power" parameter of mel-spectrogram lin_spec.mapv_inplace(|x| x.powf(1.0 / self.power)); + debug_dump_array!("linear_spectrogram.npy", lin_spec); let params = Parameters { momentum: self.momentum, @@ -106,7 +128,6 @@ impl GriffinLim { /// Parameters to provide to the griffin-lim vocoder pub struct Parameters { momentum: T, - seed: u64, iter: usize, init_random: bool, } @@ -128,7 +149,6 @@ where pub fn new() -> Self { Self { momentum: T::from_f32(0.99).unwrap(), - seed: 42, iter: 32, init_random: true, } @@ -145,12 +165,6 @@ where self } - /// A random seed to use for initializing the phase. 
- pub fn seed(mut self, seed: u64) -> Self { - self.seed = seed; - self - } - /// Number of iterations to run - default value is 32. pub fn iter(mut self, iter: usize) -> Self { self.iter = iter; @@ -273,8 +287,8 @@ pub fn griffin_lim( noverlap: usize, ) -> anyhow::Result> where - T: realfft::FftNum + Float + FloatConst + Display + SampleUniform, - Complex: ScalarOperand, + T: realfft::FftNum + Float + FloatConst + Display + SampleUniform + WritableElement, + Complex: ScalarOperand + WritableElement, { griffin_lim_with_params(spectrogram, nfft, noverlap, Parameters::new()) } @@ -288,11 +302,10 @@ pub fn griffin_lim_with_params( params: Parameters, ) -> anyhow::Result> where - T: realfft::FftNum + Float + FloatConst + Display + SampleUniform, - Complex: ScalarOperand, + T: realfft::FftNum + Float + FloatConst + Display + SampleUniform + WritableElement, + Complex: ScalarOperand + WritableElement, { // set up griffin lim parameters - let mut rng = Isaac64Rng::seed_from_u64(params.seed); if params.momentum > T::one() || params.momentum < T::zero() { bail!("Momentum is {}, should be in range [0,1]", params.momentum); } @@ -302,6 +315,7 @@ where // Initialise estimate let mut estimate = if params.init_random { + let mut rng = rand::thread_rng(); let mut angles = Array2::::random_using( spectrogram.raw_dim(), Uniform::from(-T::PI()..T::PI()), @@ -314,8 +328,11 @@ where } else { spectrogram.clone() }; + let mut _est_i = 1; + debug_dump_array!("estimate_spec_0.npy", estimate); + // TODO: Pre-allocate inverse and rebuilt and use `.assign` instead of `=` - // this requires some fighting with the borow checker + // this requires some fighting with the borrow checker let mut inverse: Array1; let mut rebuilt: Array2>; let mut tprev: Option>> = None; @@ -335,12 +352,15 @@ where } else { tprev = Some(rebuilt); } - // Get angles from estimate and apply to magnitueds + // Get angles from estimate and apply to magnitudes let eps = T::min_positive_value(); // get angles from new 
estimate estimate.mapv_inplace(|x| x / (x.norm() + eps)); // enforce magnitudes estimate.assign(&(&estimate * &spectrogram)); + + debug_dump_array!(format!("estimate_spec_{}.npy", _est_i), estimate); + _est_i += 1; } let mut signal = istft(&estimate, &window, planner, nfft, noverlap); let norm = T::from(nfft).unwrap(); @@ -490,6 +510,8 @@ where mod tests { use float_cmp::assert_approx_eq; use ndarray_npy::read_npy; + use rand::SeedableRng; + use rand_isaac::isaac64::Isaac64Rng; use super::*; diff --git a/tutorial/.gitignore b/tutorial/.gitignore new file mode 100644 index 0000000..aec8767 --- /dev/null +++ b/tutorial/.gitignore @@ -0,0 +1 @@ +estimate_spec* diff --git a/tutorial/README.md b/tutorial/README.md new file mode 100644 index 0000000..ee024c6 --- /dev/null +++ b/tutorial/README.md @@ -0,0 +1,153 @@ +# Stepping Through Griffin-Lim + +## Introduction + +This is an accompaniment to the Rustnation 2024 talk "Creating a Text-To-Speech +System in Rust" and is using a spectrogram of "Hello world" generated using a +version of that TTS system. + +It also uses `examples/run_griffin_lim.rs` to generate the output with the `debug_dump` +feature enabled to grab npy files of intermediate phases. + +## Prerequisite Knowledge + +So before we dive into things there's going to be some foundational maths as with any +signal processing. But this will be kept simpler just to ensure the concepts are understood. + +When we look at a sound wave that's a representation of a signal in the time domain. But we +can also look at spectrograms which are a representation of the signal in the frequency +domain. + +![image](./resources/trumpet.png) + +At the top we have the magnitude plot, this shows for each time slice and frequency the +amplitude of different sine waves we'd compose to recreate that signal. Below it is the +phase, this is the offset in time we'd apply to those same sine waves. + +The key transform to go from time to frequency domain is a Fourier transform. 
The Fourier transform +assumes a signal is periodic (repeats), and then decomposes it into summed sine and cosine waves at +different magnitudes and phase offsets. + +A Short Time Fourier Transform (STFT) is a discrete Fourier transform run on overlapping windows +so we can look at the frequency information of smaller localised slices of time instead +of the frequency information for the entire audio signal. This is more useful in practice than using +a Fourier transform directly. + +## Going from Mel to Linear + +So to start let's look at the spectrogram we're going to be vocoding. + +![image](./resources/mel_spec.png) + +To vocode this we need to reconstruct the phase spectrum so we can do an inverse +transform (ISTFT) to reconstruct the audio. + +The mel spectrogram uses mel frequency bands to compress a linear spectrogram to +a more compressed representation. The reason it does this is to be more representative of human +perception of pitch. It does that by multiplying the linear spectrogram +by a filter bank. So we need to invert that multiplication to go back to a linear +spectrogram. This is an optimisation problem as with matrix multiplication AB != BA. + +So we use a limited-memory [BFGS](https://en.wikipedia.org/wiki/Broyden–Fletcher–Goldfarb–Shanno_algorithm) solver to solve for the inverse of the matrix +multiplication. This is mainly to keep in line with the reference implementation we used. + +So if we construct the mel filter bank we get: + +![image](./resources/mel_basis.png) + +Using these two matrices we invert and get the following linear spectrogram: + +![image](./resources/linear_spec.png) + +It's expected that this will look a lot emptier higher up, the human ear is more +sensitive to lower frequencies so the mel spectrogram conversion throws away a lot +of high frequency data and accentuates the low frequency data. Near the bottom of +the spectrogram we can see a hint of the mel spectrogram's form. 
+ +## Estimating the Phase Spectrum + +The above linear spectrum was only the magnitude component; the phase is harder to +predict so often algorithms will generate a magnitude spectrum and then use an +algorithm like Griffin-Lim to reconstruct the phase information. + +So in this crate's internals we either initialise a random phase spectrum (default) +or start from zeroes. There's no benefit to starting from zeroes; it just makes tests +repeatable. It's typical in algorithms like this to allow users to either set a +random seed or a fixed initialisation to enable reproducible results when used in +research. + +```rust +let mut estimate = if params.init_random { + let mut angles = Array2::<T>::random_using( + spectrogram.raw_dim(), + Uniform::from(-T::PI()..T::PI()), + &mut rng, + ); + // realfft doesn't handle invalid input + angles.slice_mut(s![.., 0]).fill(T::zero()); + angles.slice_mut(s![.., -1]).fill(T::zero()); + &spectrogram * &angles.mapv(|t| Complex::from_polar(T::one(), t)) +} else { + spectrogram.clone() +}; +``` + +Now we have our initial guess at a complete spectrogram with magnitude and +phase components we can start to try and estimate better phase information. The +steps for this in Griffin-Lim are remarkably simple. + +1. Perform an inverse STFT to get audio data +2. Perform an STFT to go back to the frequency domain +3. Replace magnitude with our "known correct" magnitude and adjust phase +4. Repeat until the stopping criterion is met + +But how does this work? Well, if we change the phase then the frequencies +interfere with each other and will change the resulting magnitude spectrum. +We don't want to change our magnitude so then reapplying it to the rebuilt spectrum +negates that interference and we should end up with a phase that interferes less. +This prioritises making the signal consistent in that each ISTFT/STFT pair results +in the same spectrogram. 
+ +However, successive FFTs will cause the phase to gradually drift in one direction; +this can result in the warbly sound that people complain of with Griffin-Lim and is +ultimately why it's been abandoned for slower but more accurate neural network based +vocoders. + +This is only a high-level overview; if you're interested in understanding this +more fully, the paper for Griffin-Lim can be found easily via Google. Also the maths +behind this generally comes from signal processing (STFT, Fourier Transforms) +and non-linear optimisation (BFGS). There are numerous courses/videos/books +that cover these if you want to find a resource to go deeper into it. + +### Some Sample Steps + +Where we start on our estimate: + +![image](./resources/estimate_0.png) + +After the first iteration: + +![image](./resources/estimate_1.png) + +The end: + +![image](./resources/estimate_10.png) + +Now I'm not sure where these blocks come from in the spectrogram generation, +the current feeling is it's a mistake in my plotting or some nuance of librosa +fiddling with the spectrogram for plotting purposes. Any insight on this would +be appreciated! + +## The Audio + +And listen to the output: + +![audio](https://github.com/emotechlab/griffin-lim/raw/main/tutorial/output.wav) + +## Conclusion + +Hopefully this gave some insight into the process. One thing skipped is that +this is a "fast" version of Griffin-Lim which adds a momentum parameter to move +more quickly towards convergence. We also made substantial use of +[librosa](https://librosa.org/) as a reference implementation to test correctness +and examine how the algorithm works. 
diff --git a/tutorial/generate_plots.py b/tutorial/generate_plots.py new file mode 100644 index 0000000..99b0107 --- /dev/null +++ b/tutorial/generate_plots.py @@ -0,0 +1,68 @@ +from librosa import display, feature +import librosa +import numpy as np +import os +import matplotlib.pyplot as plt +import sys + +# Rerun the rust code +re_run = False + +def plot_complex_spec(spec, name): + magnitude, phase = librosa.magphase(spec) + fig, ax = plt.subplots(nrows=2) + + + img = librosa.display.specshow(magnitude, x_axis='time', + y_axis='linear', sr=22050, + ax=ax[0]) + img = librosa.display.specshow(np.angle(phase), x_axis='time', + y_axis='linear', sr=22050, + ax=ax[1]) + + ax[0].set_title("Magnitude Spectrum") + ax[1].set_title("Phase Spectrum") + plt.tight_layout() + plt.savefig(name) + plt.close() + +y, sr = librosa.load(librosa.ex('trumpet')) +S = librosa.stft(y) +plot_complex_spec(S, "resources/trumpet.png") + +S = np.load("hello_world.npy") +if re_run: + os.system(f"cargo run --features debug_dump --example run_griffin_lim -- --input hello_world.npy --power 1.7 --iters 10" ) + +mel_basis = librosa.filters.mel(sr=22050, n_fft=1024, n_mels=80) +print(mel_basis.shape) + +linspec =np.load("linear_spectrogram.npy") + +fig, ax = plt.subplots() +img = librosa.display.specshow(mel_basis, x_axis="linear") +ax.set(ylabel='Mel filter', title='Mel filter bank') +fig.colorbar(img, ax=ax) + +plt.savefig('resources/mel_basis.png', bbox_inches='tight') + + +fig, ax = plt.subplots() +img = librosa.display.specshow(S, x_axis='time', + y_axis='mel', sr=22050, + fmax=8000, ax=ax) +fig.colorbar(img, ax=ax, format='%+2.0f dB') + +plt.savefig('resources/mel_spec.png', bbox_inches='tight') + +S = np.load("linear_spectrogram.npy") +fig, ax = plt.subplots() +img = librosa.display.specshow(S, x_axis='time', + y_axis='linear', sr=22050, + ax=ax) +fig.colorbar(img, ax=ax, format='%+2.0f dB') + +plt.savefig('resources/linear_spec.png', bbox_inches='tight') + +for i in range(0, 11): + 
plot_complex_spec(np.load(f"estimate_spec_{i}.npy"), f"resources/estimate_{i}.png") diff --git a/tutorial/hello_world.npy b/tutorial/hello_world.npy new file mode 100644 index 0000000..c6a632e Binary files /dev/null and b/tutorial/hello_world.npy differ diff --git a/tutorial/output.wav b/tutorial/output.wav new file mode 100644 index 0000000..9cd54f1 Binary files /dev/null and b/tutorial/output.wav differ diff --git a/tutorial/requirements.txt b/tutorial/requirements.txt new file mode 100644 index 0000000..01b5021 --- /dev/null +++ b/tutorial/requirements.txt @@ -0,0 +1,3 @@ +librosa==0.9.2 +numpy==1.19.3 +matplotlib==3.3.4 diff --git a/tutorial/resources/estimate_0.png b/tutorial/resources/estimate_0.png new file mode 100644 index 0000000..50880f9 Binary files /dev/null and b/tutorial/resources/estimate_0.png differ diff --git a/tutorial/resources/estimate_1.png b/tutorial/resources/estimate_1.png new file mode 100644 index 0000000..63a8f95 Binary files /dev/null and b/tutorial/resources/estimate_1.png differ diff --git a/tutorial/resources/estimate_10.png b/tutorial/resources/estimate_10.png new file mode 100644 index 0000000..21421e7 Binary files /dev/null and b/tutorial/resources/estimate_10.png differ diff --git a/tutorial/resources/estimate_2.png b/tutorial/resources/estimate_2.png new file mode 100644 index 0000000..312c09b Binary files /dev/null and b/tutorial/resources/estimate_2.png differ diff --git a/tutorial/resources/estimate_3.png b/tutorial/resources/estimate_3.png new file mode 100644 index 0000000..f1c97d1 Binary files /dev/null and b/tutorial/resources/estimate_3.png differ diff --git a/tutorial/resources/estimate_4.png b/tutorial/resources/estimate_4.png new file mode 100644 index 0000000..da2e5b8 Binary files /dev/null and b/tutorial/resources/estimate_4.png differ diff --git a/tutorial/resources/estimate_5.png b/tutorial/resources/estimate_5.png new file mode 100644 index 0000000..e87000c Binary files /dev/null and 
b/tutorial/resources/estimate_5.png differ diff --git a/tutorial/resources/estimate_6.png b/tutorial/resources/estimate_6.png new file mode 100644 index 0000000..4e62fb3 Binary files /dev/null and b/tutorial/resources/estimate_6.png differ diff --git a/tutorial/resources/estimate_7.png b/tutorial/resources/estimate_7.png new file mode 100644 index 0000000..47f8a31 Binary files /dev/null and b/tutorial/resources/estimate_7.png differ diff --git a/tutorial/resources/estimate_8.png b/tutorial/resources/estimate_8.png new file mode 100644 index 0000000..9adbb8f Binary files /dev/null and b/tutorial/resources/estimate_8.png differ diff --git a/tutorial/resources/estimate_9.png b/tutorial/resources/estimate_9.png new file mode 100644 index 0000000..54b2421 Binary files /dev/null and b/tutorial/resources/estimate_9.png differ diff --git a/tutorial/resources/linear_spec.png b/tutorial/resources/linear_spec.png new file mode 100644 index 0000000..27506d2 Binary files /dev/null and b/tutorial/resources/linear_spec.png differ diff --git a/tutorial/resources/mel_basis.png b/tutorial/resources/mel_basis.png new file mode 100644 index 0000000..06afc24 Binary files /dev/null and b/tutorial/resources/mel_basis.png differ diff --git a/tutorial/resources/mel_spec.png b/tutorial/resources/mel_spec.png new file mode 100644 index 0000000..8a6c2a4 Binary files /dev/null and b/tutorial/resources/mel_spec.png differ diff --git a/tutorial/resources/trumpet.png b/tutorial/resources/trumpet.png new file mode 100644 index 0000000..36bfbb3 Binary files /dev/null and b/tutorial/resources/trumpet.png differ