Skip to content

Commit

Permalink
Whisper encode audio
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 12, 2024
1 parent 571224e commit 63e2224
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 12 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Run tests
- name: Run clippy
run: rustup update; cargo clippy --all-targets -- -D warnings

fmt:
Expand All @@ -37,7 +37,7 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Run tests
- name: Format
run: cargo fmt --all --check

# macos_test:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ members = [
"crates/luminal_nn",
"crates/luminal_training",
]
exclude = ["examples/whisper", "crates/luminal_metal", "crates/luminal_cuda"]
exclude = ["crates/luminal_metal", "crates/luminal_cuda"]
4 changes: 3 additions & 1 deletion examples/whisper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@ tokenizers = "0.15.2"
itertools = "0.12.1"
metal-rs = { version = "0.27.0", package = "metal", features = [
"mps",
] }
] }
symphonia = "0.5.4"
anyhow = "1.0.83"
Binary file added examples/whisper/setup/melfilters.bytes
Binary file not shown.
Binary file added examples/whisper/setup/melfilters128.bytes
Binary file not shown.
Binary file added examples/whisper/setup/samples_jfk.wav
Binary file not shown.
75 changes: 75 additions & 0 deletions examples/whisper/src/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,3 +287,78 @@ pub fn log_mel_spectrogram_<T: Float>(
pub fn pcm_to_mel<T: Float>(n_mel: usize, samples: &[T], filters: &[T]) -> Vec<T> {
log_mel_spectrogram_(samples, filters, N_FFT, HOP_LENGTH, n_mel, false)
}

use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
// Open the media source.
let src = std::fs::File::open(path)?;

// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());

// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();

// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();

// Probe the media source.
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
// Get the instantiated format reader.
let mut format = probed.format;

// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.expect("no supported audio tracks");

// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();

// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}

// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}
23 changes: 22 additions & 1 deletion examples/whisper/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ fn main() {
// Construct encoder graph
let mut enc_cx = Graph::new();
let encoder = model::AudioEncoder::initialize(&mut enc_cx);
enc_cx.keep_tensors(params(&encoder));
let mut audio_input = enc_cx.tensor::<(Const<1>, Dyn<'s'>, Const<{ model::N_MEL_BINS }>)>();
let mut encoded = encoder.forward(audio_input).retrieve();
let mut encoded = encoder.forward(audio_input);
encoded.retrieve();
loader::load("setup/whisper.gguf", &encoder, &mut enc_cx);

// Construct decoder graph
Expand All @@ -34,6 +36,7 @@ fn main() {
)
})
.collect();
cache_src.set_dyn(vec![], &[1, 6, 0, 64]);
let (mut logits, mut enc_proj_states, mut cache_dest) = decoder.forward((
&encoder_output,
text_input,
Expand All @@ -59,4 +62,22 @@ fn main() {
&mut logits,
),
);

// Process audio into mel spectrogram
let mel_bytes = include_bytes!("../setup/melfilters.bytes").as_slice();
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let (pcm_data, sample_rate) = audio::pcm_decode("setup/samples_jfk.wav").unwrap();
let mel = audio::pcm_to_mel(80, &pcm_data, &mel_filters);
let mel_len = mel.len();

// Encode audio
audio_input.set_dyn(mel, &[1, mel_len / 80, 80]);
enc_cx.execute();
transfer_data(encoded, &mut enc_cx, encoder_output, &mut dec_cx);

// // Decode text
// for _ in 0..1000 {

// }
}
12 changes: 8 additions & 4 deletions examples/whisper/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ fn sinusoids<const CHANNELS: usize, Length: Dimension>(
impl<Batch: Dimension, Seq: Dimension> Module<GraphTensor<(Batch, Seq, Const<N_MEL_BINS>)>>
for AudioEncoder
{
type Output = GraphTensor<(Batch, Seq, Const<D_MODEL>)>;
type Output = Vec<GraphTensor<(Batch, Seq, Const<D_MODEL>)>>;
fn forward(&self, input: GraphTensor<(Batch, Seq, Const<N_MEL_BINS>)>) -> Self::Output {
// Conv layers
let x = self
Expand All @@ -321,9 +321,13 @@ impl<Batch: Dimension, Seq: Dimension> Module<GraphTensor<(Batch, Seq, Const<N_M
// Sinusoidal positional embedding
x += sinusoids::<D_MODEL, Seq>(x.graph()).expand();
// Transformer layers
let output = self.layers.forward(x);
// Final layer norm
output.layer_norm::<Axis<2>, _>(1e-5)
let mut outputs = vec![];
let mut output = x;
for l in &self.layers {
output = l.forward(output);
outputs.push(output);
}
outputs
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/graph_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ impl<S: Shape> GraphTensor<S> {
pub fn set_dyn<T: Data + Clone>(self, data: T, shape: &[usize]) -> Self {
// Report dyn dim values to graph dyn map
assert_eq!(
S::realized_shape().len(),
S::NUM_DIMS,
shape.len(),
"Number of dimensions don't match!"
"Number of dimensions do not match!"
);
for (d, s) in S::realized_shape().iter().zip(shape.iter()) {
if let Some(c) = d.to_symbols().pop() {
Expand Down
2 changes: 1 addition & 1 deletion src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ impl Serializer {
self.current_path.pop();
}
}
pub fn module<T: SerializeModule>(&mut self, name: &str, module: &T) {
pub fn module(&mut self, name: &str, module: impl SerializeModule) {
if !name.is_empty() {
// Add new path component
self.current_path.push(name.to_string());
Expand Down

0 comments on commit 63e2224

Please sign in to comment.