From fe9b9ff75edebf1e0e326f64577d64df3b5858db Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Wed, 9 Aug 2023 11:34:13 +0000 Subject: [PATCH] 2023-08-09 nightly release (9f5fa84bbd5cbea15d289e6b39c6b4fdfa1e27b3) --- .../asr/librispeech_conformer_rnnt/README.md | 6 +- .../asr/librispeech_conformer_rnnt/train.py | 6 +- examples/avsr/README.md | 2 +- .../ctc_forced_alignment_api_tutorial.py | 278 ++++----- ...lignment_for_multilingual_data_tutorial.py | 552 ++++++++---------- .../prototype/vggish_pipeline_test.py | 27 +- torchaudio/pipelines/_wav2vec2/aligner.py | 12 +- torchaudio/pipelines/_wav2vec2/impl.py | 26 +- torchaudio/pipelines/_wav2vec2/utils.py | 37 +- 9 files changed, 417 insertions(+), 529 deletions(-) diff --git a/examples/asr/librispeech_conformer_rnnt/README.md b/examples/asr/librispeech_conformer_rnnt/README.md index 4ad68c3165..720d9b2575 100644 --- a/examples/asr/librispeech_conformer_rnnt/README.md +++ b/examples/asr/librispeech_conformer_rnnt/README.md @@ -12,7 +12,7 @@ To build TorchAudio from source, refer to the [contributing guidelines](https:// ### Install additional dependencies ```bash -pip install pytorch-lightning sentencepiece +pip install pytorch-lightning sentencepiece tensorboard ``` ## Usage @@ -27,7 +27,7 @@ pip install pytorch-lightning sentencepiece Sample SLURM command: ``` -srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp_dir ./experiments --librispeech_path ./librispeech/ --global_stats_path ./global_stats.json --sp_model_path ./spm_unigram_1023.model --epochs 160 +srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp-dir ./experiments --librispeech-path ./librispeech/ --global-stats-path ./global_stats.json --sp-model-path ./spm_unigram_1023.model --epochs 160 ``` ### Evaluation @@ -36,7 +36,7 @@ srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train. Sample SLURM command: ``` -srun python eval.py --checkpoint_path ./experiments/checkpoints/epoch=159.ckpt --librispeech_path ./librispeech/ --sp_model_path ./spm_unigram_1023.model --use_cuda +srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=159.ckpt --librispeech-path ./librispeech/ --sp-model-path ./spm_unigram_1023.model --use-cuda ``` The table below contains WER results for various splits. 
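The `train.py` change that follows swaps the deprecated `DDPPlugin`/`gpus` arguments for the current Lightning API. As a minimal sketch of what the updated `Trainer` construction looks like (illustrative values only, assuming a recent `pytorch_lightning` release; not taken verbatim from the recipe, which wires these values from CLI arguments):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

# Illustrative single-node setup mirroring the arguments used in the recipe.
trainer = Trainer(
    default_root_dir="./experiments",
    max_epochs=160,
    num_nodes=1,
    devices=8,  # replaces the removed `gpus=` argument
    accelerator="gpu",
    strategy=DDPStrategy(find_unused_parameters=False),  # replaces DDPPlugin
    gradient_clip_val=10.0,
)
```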
diff --git a/examples/asr/librispeech_conformer_rnnt/train.py b/examples/asr/librispeech_conformer_rnnt/train.py index 530e3257b4..5d50d7e082 100644 --- a/examples/asr/librispeech_conformer_rnnt/train.py +++ b/examples/asr/librispeech_conformer_rnnt/train.py @@ -6,7 +6,7 @@ from lightning import ConformerRNNTModule from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint -from pytorch_lightning.plugins import DDPPlugin +from pytorch_lightning.strategies import DDPStrategy from transforms import get_data_module @@ -39,9 +39,9 @@ def run_train(args): default_root_dir=args.exp_dir, max_epochs=args.epochs, num_nodes=args.nodes, - gpus=args.gpus, + devices=args.gpus, accelerator="gpu", - strategy=DDPPlugin(find_unused_parameters=False), + strategy=DDPStrategy(find_unused_parameters=False), callbacks=callbacks, reload_dataloaders_every_n_epochs=1, gradient_clip_val=10.0, diff --git a/examples/avsr/README.md b/examples/avsr/README.md index ea18c9175a..65bd3621db 100644 --- a/examples/avsr/README.md +++ b/examples/avsr/README.md @@ -12,7 +12,7 @@ This directory contains the training recipe for real-time audio, visual, and audio-visual speech recognition (ASR, VSR, AV-ASR) models, which is an extension of [Auto-AVSR](https://arxiv.org/abs/2303.14307). -Please refer to [this tutorial]() for real-time AV-ASR inference from microphone and camera. +Please refer to [this tutorial](https://pytorch.org/audio/main/tutorials/device_avsr.html) for real-time AV-ASR inference from microphone and camera. ## Preparation diff --git a/examples/tutorials/ctc_forced_alignment_api_tutorial.py b/examples/tutorials/ctc_forced_alignment_api_tutorial.py index 33af5f8853..d6282a07b8 100644 --- a/examples/tutorials/ctc_forced_alignment_api_tutorial.py +++ b/examples/tutorials/ctc_forced_alignment_api_tutorial.py @@ -2,50 +2,36 @@ CTC forced alignment API tutorial ================================= -**Author**: `Xiaohui Zhang `__ - - -This tutorial shows how to align transcripts to speech using -:py:func:`torchaudio.functional.forced_align` -which was developed along the work of -`Scaling Speech Technology to 1,000+ Languages `__. +**Author**: `Xiaohui Zhang `__, `Moto Hira `__ The forced alignment is a process to align transcript with speech. -We cover the basics of forced alignment in `Forced Alignment with -Wav2Vec2 <./forced_alignment_tutorial.html>`__ with simplified -step-by-step Python implementations. +This tutorial shows how to align transcripts to speech using +:py:func:`torchaudio.functional.forced_align` which was developed along the work of +`Scaling Speech Technology to 1,000+ Languages +`__. :py:func:`~torchaudio.functional.forced_align` has custom CPU and CUDA implementations which are more performant than the vanilla Python implementation above, and are more accurate. It can also handle missing transcript with special token. -For examples of aligning multiple languages, please refer to -`Forced alignment for multilingual data <./forced_alignment_for_multilingual_data_tutorial.html>`__. +There is also a high-level API, :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, +which wraps the pre/post-processing explained in this tutorial and makes it easy +to run forced-alignments. +`Forced alignment for multilingual data +<./forced_alignment_for_multilingual_data_tutorial.html>`__ uses this API to +illustrate how to align non-English transcripts. 
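Before going through the details, here is a minimal, self-contained sketch
(using a random emission purely for illustration) of the tensor shapes that
:py:func:`~torchaudio.functional.forced_align` expects and returns.

.. code-block:: python

   import torch
   import torchaudio.functional as F

   emission = torch.randn(1, 50, 5).log_softmax(dim=-1)      # (batch, frame, num_tokens), log-probabilities
   targets = torch.tensor([[1, 2, 3, 2]], dtype=torch.int32)  # (batch, num_target_tokens)

   alignments, scores = F.forced_align(emission, targets, blank=0)
   # Both outputs have shape (batch, frame): one token index and one score per frame.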
""" import torch import torchaudio - print(torch.__version__) print(torchaudio.__version__) ###################################################################### # -from dataclasses import dataclass -from typing import List - -import IPython -import matplotlib.pyplot as plt - -###################################################################### -# - -from torchaudio.functional import forced_align - -torch.random.manual_seed(0) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(device) @@ -53,16 +39,24 @@ # Preparation # ----------- # +import IPython +import matplotlib.pyplot as plt + +import torchaudio.functional as F + +###################################################################### # First we prepare the speech data and the transcript we area going # to use. # SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav") -TRANSCRIPT = "I|HAD|THAT|CURIOSITY|BESIDE|ME|AT|THIS|MOMENT" +waveform, _ = torchaudio.load(SPEECH_FILE) +TRANSCRIPT = "i had that curiosity beside me at this moment".split() + ###################################################################### -# Generating emissions and tokens -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Generating emissions +# ~~~~~~~~~~~~~~~~~~~~ # # :py:func:`~torchaudio.functional.forced_align` takes emission and # token sequences and outputs timestaps of the tokens and their scores. @@ -70,30 +64,26 @@ # Emission reperesents the frame-wise probability distribution over # tokens, and it can be obtained by passing waveform to an acoustic # model. -# Tokens are numerical expression of transcripts. It can be obtained by -# simply mapping each character to the index of token list. -# The emission and the token sequences must be using the same set of tokens. # -# We can use pre-trained Wav2Vec2 model to obtain emission from speech, -# and map transcript to tokens. -# Here, we use :py:data:`~torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H`, -# which bandles pre-trained model weights with associated labels. +# Tokens are numerical expression of transcripts. There are many ways to +# tokenize transcripts, but here, we simply map alphabets into integer, +# which is how labels were constructed when the acoustice model we are +# going to use was trained. +# +# We will use a pre-trained Wav2Vec2 model, +# :py:data:`torchaudio.pipelines.MMS_FA`, to obtain emission and tokenize +# the transcript. # -bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H -model = bundle.get_model().to(device) +bundle = torchaudio.pipelines.MMS_FA + +model = bundle.get_model(with_star=False).to(device) with torch.inference_mode(): - waveform, _ = torchaudio.load(SPEECH_FILE) emission, _ = model(waveform.to(device)) - emission = torch.log_softmax(emission, dim=-1) - -num_frames = emission.size(1) ###################################################################### # - - def plot_emission(emission): fig, ax = plt.subplots() ax.imshow(emission.cpu().T) @@ -106,20 +96,24 @@ def plot_emission(emission): plot_emission(emission[0]) ###################################################################### +# Tokenize the transcript +# ~~~~~~~~~~~~~~~~~~~~~~~ +# # We create a dictionary, which maps each label into token. 
-labels = bundle.get_labels() -DICTIONARY = {c: i for i, c in enumerate(labels)} - +LABELS = bundle.get_labels(star=None) +DICTIONARY = bundle.get_dict(star=None) for k, v in DICTIONARY.items(): print(f"{k}: {v}") ###################################################################### # converting transcript to tokens is as simple as -tokenized_transcript = [DICTIONARY[c] for c in TRANSCRIPT] +tokenized_transcript = [DICTIONARY[c] for word in TRANSCRIPT for c in word] -print(" ".join(str(t) for t in tokenized_transcript)) +for t in tokenized_transcript: + print(t, end=" ") +print() ###################################################################### # Computing frame-level alignments @@ -129,17 +123,11 @@ def plot_emission(emission): # frame-level alignment. For the detail of function signature, please # refer to :py:func:`~torchaudio.functional.forced_align`. # -# def align(emission, tokens): - alignments, scores = forced_align( - emission, - targets=torch.tensor([tokens], dtype=torch.int32, device=emission.device), - input_lengths=torch.tensor([emission.size(1)], device=emission.device), - target_lengths=torch.tensor([len(tokens)], device=emission.device), - blank=0, - ) + targets = torch.tensor([tokens], dtype=torch.int32, device=device) + alignments, scores = F.forced_align(emission, targets, blank=0) alignments, scores = alignments[0], scores[0] # remove batch dimension for simplicity scores = scores.exp() # convert back to probability @@ -154,7 +142,7 @@ def align(emission, tokens): # emission, which is different from the original waveform. for i, (ali, score) in enumerate(zip(aligned_tokens, alignment_scores)): - print(f"{i:3d}:\t{ali:2d} [{labels[ali]}], {score:.2f}") + print(f"{i:3d}:\t{ali:2d} [{LABELS[ali]}], {score:.2f}") ###################################################################### # @@ -209,46 +197,14 @@ def align(emission, tokens): # which explains what token (in transcript) is present at what time span. 
-@dataclass -class TokenSpan: - index: int # index of token in transcript - start: int # start time (inclusive) - end: int # end time (exclusive) - score: float - - def __len__(self) -> int: - return self.end - self.start - - ###################################################################### # - -def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]: - prev_token = blank - i = start = -1 - spans = [] - for t, token in enumerate(tokens): - if token != prev_token: - if prev_token != blank: - spans.append(TokenSpan(i, start, t, scores[start:t].mean().item())) - if token != blank: - i += 1 - start = t - prev_token = token - if prev_token != blank: - spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item())) - return spans - - -###################################################################### -# - -token_spans = merge_tokens(aligned_tokens, alignment_scores) +token_spans = F.merge_tokens(aligned_tokens, alignment_scores) print("Token\tTime\tScore") for s in token_spans: - print(f"{TRANSCRIPT[s.index]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}") + print(f"{LABELS[s.token]}\t[{s.start:3d}, {s.end:3d})\t{s.score:.2f}") ###################################################################### # Visualization @@ -256,18 +212,17 @@ def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]: # -def plot_scores(spans, scores, transcript): +def plot_scores(spans, scores): fig, ax = plt.subplots() ax.set_title("frame-level and token-level confidence scores") span_xs, span_hs, span_ws = [], [], [] frame_xs, frame_hs = [], [] for span in spans: - token = transcript[span.index] - if token != "|": + if LABELS[span.token] != "|": span_xs.append((span.end + span.start) / 2 + 0.4) span_hs.append(span.score) span_ws.append(span.end - span.start) - ax.annotate(token, (span.start + 0.8, -0.07), weight="bold") + ax.annotate(LABELS[span.token], (span.start + 0.8, -0.07), weight="bold") for t in range(span.start, span.end): frame_xs.append(t + 1) frame_hs.append(scores[t].item()) @@ -279,7 +234,7 @@ def plot_scores(spans, scores, transcript): fig.tight_layout() -plot_scores(token_spans, alignment_scores, TRANSCRIPT) +plot_scores(token_spans, alignment_scores) ###################################################################### @@ -295,30 +250,18 @@ def plot_scores(spans, scores, transcript): # alignments and listening to them. 
-@dataclass -class WordSpan: - token_spans: List[TokenSpan] - score: float - - # Obtain word alignments from token alignments -def merge_words(token_spans, transcript, separator="|") -> List[WordSpan]: - def _score(t_spans): - return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans) - - words = [] +def unflatten(list_, lengths): + assert len(list_) == sum(lengths) i = 0 - - for j, span in enumerate(token_spans): - if transcript[span.index] == separator: - words.append(WordSpan(token_spans[i:j], _score(token_spans[i:j]))) - i = j + 1 - if i < len(token_spans): - words.append(WordSpan(token_spans[i:], _score(token_spans[i:]))) - return words + ret = [] + for l in lengths: + ret.append(list_[i : i + l]) + i += l + return ret -word_spans = merge_words(token_spans, TRANSCRIPT) +word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT]) ###################################################################### @@ -326,45 +269,50 @@ def _score(t_spans): # ~~~~~~~~~~~~~ # +# Compute average score weighted by the span length +def _score(spans): + return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans) -def plot_alignments(waveform, word_spans, num_frames, transcript, sample_rate=bundle.sample_rate): - fig, ax = plt.subplots() - ax.specgram(waveform[0], Fs=sample_rate) - ratio = waveform.size(1) / sample_rate / num_frames - for w_span in word_spans: - t_spans = w_span.token_spans - t0, t1 = t_spans[0].start, t_spans[-1].end - ax.axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white") - ax.annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False) +def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate): + ratio = waveform.size(1) / emission.size(1) / sample_rate - for span in t_spans: - token = transcript[span.index] - ax.annotate(token, (span.start * ratio, sample_rate * 0.53), annotation_clip=False) + fig, axes = plt.subplots(2, 1) + axes[0].imshow(emission[0].detach().cpu().T, aspect="auto") + axes[0].set_title("Emission") + axes[0].set_xticks([]) - ax.set_xlabel("time [second]") - ax.set_xlim([0, None]) - fig.tight_layout() + axes[1].specgram(waveform[0], Fs=sample_rate) + for t_spans, chars in zip(token_spans, transcript): + t0, t1 = t_spans[0].start + 0.1, t_spans[-1].end - 0.1 + axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white") + axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white") + axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False) + for span, char in zip(t_spans, chars): + t0 = span.start * ratio + axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False) -plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT) + axes[1].set_xlabel("time [second]") + axes[1].set_xlim([0, None]) + fig.tight_layout() -###################################################################### +plot_alignments(waveform, word_spans, emission, TRANSCRIPT) -def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle.sample_rate): +###################################################################### +def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate): ratio = waveform.size(1) / num_frames - t0 = word_span.token_spans[0].start - t1 = word_span.token_spans[-1].end - x0 = int(ratio * t0) - x1 = int(ratio * t1) - tokens = "".join(transcript[t.index] for t in word_span.token_spans) - 
print(f"{tokens} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") + x0 = int(ratio * spans[0].start) + x1 = int(ratio * spans[-1].end) + print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") segment = waveform[:, x0:x1] return IPython.display.Audio(segment.numpy(), rate=sample_rate) +num_frames = emission.size(1) + ###################################################################### # Generate the audio for each segment @@ -374,47 +322,47 @@ def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[0], num_frames, TRANSCRIPT[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[1], num_frames, TRANSCRIPT[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[2], num_frames, TRANSCRIPT[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[3], num_frames, TRANSCRIPT[3]) ###################################################################### # -preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[4], num_frames, TRANSCRIPT[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[5], num_frames, TRANSCRIPT[5]) ###################################################################### # -preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[6], num_frames, TRANSCRIPT[6]) ###################################################################### # -preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[7], num_frames, TRANSCRIPT[7]) ###################################################################### # -preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT) +preview_word(waveform, word_spans[8], num_frames, TRANSCRIPT[8]) ###################################################################### @@ -442,7 +390,7 @@ def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle # corresponding to the ```` token. 
# -star_dim = torch.zeros((1, num_frames, 1), device=device) +star_dim = torch.zeros((1, emission.size(1), 1), device=emission.device, dtype=emission.dtype) emission = torch.cat((emission, star_dim), 2) assert len(DICTIONARY) == emission.shape[2] @@ -455,10 +403,10 @@ def preview_word(waveform, word_span, num_frames, transcript, sample_rate=bundle def compute_alignments(emission, transcript, dictionary): - tokens = [dictionary[c] for c in transcript] + tokens = [dictionary[char] for word in transcript for char in word] alignment, scores = align(emission, tokens) - token_spans = merge_tokens(alignment, scores) - word_spans = merge_words(token_spans, transcript) + token_spans = F.merge_tokens(alignment, scores) + word_spans = unflatten(token_spans, [len(word) for word in transcript]) return word_spans @@ -466,26 +414,31 @@ def compute_alignments(emission, transcript, dictionary): # **Original** word_spans = compute_alignments(emission, TRANSCRIPT, DICTIONARY) -plot_alignments(waveform, word_spans, num_frames, TRANSCRIPT) +plot_alignments(waveform, word_spans, emission, TRANSCRIPT) ###################################################################### # **With token** # # Now we replace the first part of the transcript with the ```` token. -transcript = "*|THIS|MOMENT" +transcript = "* this moment".split() word_spans = compute_alignments(emission, transcript, DICTIONARY) -plot_alignments(waveform, word_spans, num_frames, transcript) +plot_alignments(waveform, word_spans, emission, transcript) + +###################################################################### +# + +preview_word(waveform, word_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript) +preview_word(waveform, word_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript) +preview_word(waveform, word_spans[2], num_frames, transcript[2]) ###################################################################### # @@ -497,9 +450,9 @@ def compute_alignments(emission, transcript, dictionary): # without using ```` token. # It demonstrates the effect of ```` token for dealing with deletion errors. -transcript = "THIS|MOMENT" +transcript = "this moment".split() word_spans = compute_alignments(emission, transcript, DICTIONARY) -plot_alignments(waveform, word_spans, num_frames, transcript) +plot_alignments(waveform, word_spans, emission, transcript) ###################################################################### # Conclusion @@ -517,7 +470,6 @@ def compute_alignments(emission, transcript, dictionary): # --------------- # # Thanks to `Vineel Pratap `__ and `Zhaoheng -# Ni `__ for working on the forced aligner API, and `Moto -# Hira `__ for providing alignment merging and -# visualization utilities. +# Ni `__ for developing and open-sourcing the +# forced aligner API. 
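The conclusion above points to :py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which the next tutorial in this patch builds on. As a rough preview (a sketch, not part of the patch), the same English alignment written against that high-level API looks like this:

```python
import torch
import torchaudio
from torchaudio.pipelines import MMS_FA as bundle

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = bundle.get_model().to(device)  # keeps the <star> dimension by default
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
waveform, _ = torchaudio.load(SPEECH_FILE)
transcript = "i had that curiosity beside me at this moment".split()

with torch.inference_mode():
    emission, _ = model(waveform.to(device))
    token_spans = aligner(emission[0], tokenizer(transcript))

# One list of TokenSpan per word; start/end are emission-frame indices.
for word, spans in zip(transcript, token_spans):
    print(word, spans[0].start, spans[-1].end)
```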
#

diff --git a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
index 7416f0f8b9..485dbd097b 100644
--- a/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
+++ b/examples/tutorials/forced_alignment_for_multilingual_data_tutorial.py
@@ -2,18 +2,16 @@
Forced alignment for multilingual data
======================================

-**Author**: `Xiaohui Zhang `__
-
-This tutorial shows how to compute forced alignments for speech data
-from multiple non-English languages using ``torchaudio``'s CTC forced alignment
-API described in `CTC forced alignment tutorial <./forced_alignment_tutorial.html>`__
-and the multilingual Wav2vec2 model proposed in the paper `Scaling
-Speech Technology to 1,000+
-Languages `__.
-
-The model was trained on 23K of audio data from 1100+ languages using
-the `uroman vocabulary `__
-as targets.
+**Authors**: `Xiaohui Zhang `__, `Moto Hira `__.
+
+This tutorial shows how to align transcripts to speech for non-English languages.
+
+The process of aligning non-English (normalized) transcripts is identical to aligning
+English (normalized) transcripts, and the process for English is covered in detail in
+`CTC forced alignment tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
+In this tutorial, we use TorchAudio's high-level API,
+:py:class:`torchaudio.pipelines.Wav2Vec2FABundle`, which packages the pre-trained
+model, tokenizer and aligner, to perform the forced alignment with less code.
"""

import torch
@@ -25,114 +23,109 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

-from dataclasses import dataclass

######################################################################
-# Preparation
-# -----------
#
-from typing import Dict, List
+from typing import List

import IPython
import matplotlib.pyplot as plt

-from torchaudio.functional import forced_align
-
######################################################################
+# Creating the pipeline
+# ---------------------
#
-
-SAMPLE_RATE = 16000
-
-
-######################################################################
+# First, we instantiate the model and pre/post-processing pipelines.
#
-# Here we define utility functions for computing the frame-level
-# alignments (using the API :py:func:`torchaudio.functional.forced_align`),
-# token-level and word-level alignments.
-# For the detail of these functions please refer to
-# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__.
+# The following diagram illustrates the process of alignment.
#
+# .. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2fabundle.png
+#
+# The waveform is passed to an acoustic model, which produces the sequence of
+# probability distributions over tokens.
+# The transcript is passed to the tokenizer, which converts it to
+# a sequence of tokens.
+# The aligner takes the results from the acoustic model and the tokenizer and generates
+# timestamps for each token.
+#
+# .. note::
+#
+#    This process expects that the input transcript is already normalized.
+#    The process of normalization, which involves romanization of non-English
+#    languages, is language-dependent, so it is not covered in this tutorial,
+#    but we will briefly look into it.
+#
+# The acoustic model and the tokenizer must use the same set of tokens.
+# To facilitate the creation of matching processors,
+# :py:class:`~torchaudio.pipelines.Wav2Vec2FABundle` associates a
+# pre-trained acoustic model and a tokenizer.
+# :py:data:`torchaudio.pipelines.MMS_FA` is one such instance.
+#
+# The following code instantiates a pre-trained acoustic model, a tokenizer
+# which uses the same set of tokens as the model, and an aligner.
+#
+from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

-@dataclass
-class TokenSpan:
-    index: int  # index of token in transcript
-    start: int  # start time (inclusive)
-    end: int  # end time (exclusive)
-    score: float
-
-    def __len__(self) -> int:
-        return self.end - self.start
+tokenizer = bundle.get_tokenizer()
+aligner = bundle.get_aligner()

######################################################################
+# .. note::
+#
+#    The model instantiated by :py:data:`~torchaudio.pipelines.MMS_FA`'s
+#    :py:meth:`~torchaudio.pipelines.Wav2Vec2FABundle.get_model`
+#    method by default includes the feature dimension for the ``<star>`` token.
+#    You can disable this by passing ``with_star=False``.
#
-
-
-@dataclass
-class WordSpan:
-    token_spans: List[TokenSpan]
-    score: float
-

######################################################################
+# The acoustic model of :py:data:`~torchaudio.pipelines.MMS_FA` was
+# created and open-sourced as part of the research project,
+# `Scaling Speech Technology to 1,000+ Languages
+# `__.
+# It was trained with 23,000 hours of audio from 1100+ languages.
#
-def align_emission_and_tokens(emission: torch.Tensor, tokens: List[int]):
-    device = emission.device
-    targets = torch.tensor([tokens], dtype=torch.int32, device=device)
-    input_lengths = torch.tensor([emission.size(1)], device=device)
-    target_lengths = torch.tensor([targets.size(1)], device=device)
+# The tokenizer simply maps the normalized characters to integers.
+# You can check the mapping as follows:

-    aligned_tokens, scores = forced_align(emission, targets, input_lengths, target_lengths, 0)
+print(bundle.get_dict())

-    scores = scores.exp()  # convert back to probability
-    aligned_tokens, scores = aligned_tokens[0], scores[0]  # remove batch dimension
-    return aligned_tokens, scores
-
-
-def merge_tokens(tokens, scores, blank=0) -> List[TokenSpan]:
-    prev_token = blank
-    i = start = -1
-    spans = []
-    for t, token in enumerate(tokens):
-        if token != prev_token:
-            if prev_token != blank:
-                spans.append(TokenSpan(i, start, t, scores[start:t].mean().item()))
-            if token != blank:
-                i += 1
-                start = t
-            prev_token = token
-    if prev_token != blank:
-        spans.append(TokenSpan(i, start, len(tokens), scores[start:].mean().item()))
-    return spans
+######################################################################
+#
+# The aligner internally uses :py:func:`torchaudio.functional.forced_align`
+# and :py:func:`torchaudio.functional.merge_tokens` to infer the
+# timestamps of the input tokens.
+#
+# The details of the underlying mechanism are covered in
+# `CTC forced alignment API tutorial <./ctc_forced_alignment_api_tutorial.html>`__,
+# so please refer to it.

-def merge_words(token_spans: List[TokenSpan], transcript: List[str]) -> List[WordSpan]:
-    def _score(t_spans):
-        return sum(s.score * len(s) for s in t_spans) / sum(len(s) for s in t_spans)
+######################################################################
+# We define a utility function that performs the forced alignment with
+# the above model, the tokenizer and the aligner.
+# +def compute_alignments(waveform: torch.Tensor, transcript: List[str]): + with torch.inference_mode(): + emission, _ = model(waveform.to(device)) + token_spans = aligner(emission[0], tokenizer(transcript)) + return emission, token_spans - word_spans = [] - i = 0 - for words in transcript: - j = i + len(words) - word_spans.append(WordSpan(token_spans[i:j], _score(token_spans[i:j]))) - i = j - return word_spans +###################################################################### +# We also define utility functions for plotting the result and previewing +# the audio segments. -def compute_alignments(emission: torch.Tensor, transcript: List[str], dictionary: Dict[str, int]): - tokens = [dictionary[c] for word in transcript for c in word] - aligned_tokens, scores = align_emission_and_tokens(emission, tokens) - token_spans = merge_tokens(aligned_tokens, scores) - word_spans = merge_words(token_spans, transcript) - return word_spans +# Compute average score weighted by the span length +def _score(spans): + return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans) -###################################################################### -# - -# utility function for plotting word alignments -def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMPLE_RATE): +def plot_alignments(waveform, token_spans, emission, transcript, sample_rate=bundle.sample_rate): ratio = waveform.size(1) / emission.size(1) / sample_rate fig, axes = plt.subplots(2, 1) @@ -141,150 +134,68 @@ def plot_alignments(waveform, word_spans, emission, transcript, sample_rate=SAMP axes[0].set_xticks([]) axes[1].specgram(waveform[0], Fs=sample_rate) - for w_span, chars in zip(word_spans, transcript): - t_spans = w_span.token_spans + for t_spans, chars in zip(token_spans, transcript): t0, t1 = t_spans[0].start, t_spans[-1].end + axes[0].axvspan(t0 - 0.5, t1 - 0.5, facecolor="None", hatch="/", edgecolor="white") axes[1].axvspan(ratio * t0, ratio * t1, facecolor="None", hatch="/", edgecolor="white") - axes[1].annotate(f"{w_span.score:.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False) + axes[1].annotate(f"{_score(t_spans):.2f}", (ratio * t0, sample_rate * 0.51), annotation_clip=False) for span, char in zip(t_spans, chars): - axes[1].annotate(char, (span.start * ratio, sample_rate * 0.55), annotation_clip=False) + t0 = span.start * ratio + axes[1].annotate(char, (t0, sample_rate * 0.55), annotation_clip=False) axes[1].set_xlabel("time [second]") fig.tight_layout() - return IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # - - -def preview_word(waveform, word_span, num_frames, transcript, sample_rate=SAMPLE_RATE): +def preview_word(waveform, spans, num_frames, transcript, sample_rate=bundle.sample_rate): ratio = waveform.size(1) / num_frames - t0 = word_span.token_spans[0].start - t1 = word_span.token_spans[-1].end - x0 = int(ratio * t0) - x1 = int(ratio * t1) - print(f"{transcript} ({word_span.score:.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") + x0 = int(ratio * spans[0].start) + x1 = int(ratio * spans[-1].end) + print(f"{transcript} ({_score(spans):.2f}): {x0 / sample_rate:.3f} - {x1 / sample_rate:.3f} sec") segment = waveform[:, x0:x1] return IPython.display.Audio(segment.numpy(), rate=sample_rate) ###################################################################### -# Aligning multilingual data +# Normalizing the transcript # -------------------------- # -# Here we show examples of computing 
forced alignments of utterances in -# 5 languages using the multilingual Wav2vec2 model, with the alignments visualized. -# One can also play the whole audio and audio segments aligned with each word, in -# order to verify the alignment quality. Here we first load the model and dictionary. -# - -from torchaudio.models import wav2vec2_model - -model = wav2vec2_model( - extractor_mode="layer_norm", - extractor_conv_layer_config=[ - (512, 10, 5), - (512, 3, 2), - (512, 3, 2), - (512, 3, 2), - (512, 3, 2), - (512, 2, 2), - (512, 2, 2), - ], - extractor_conv_bias=True, - encoder_embed_dim=1024, - encoder_projection_dropout=0.0, - encoder_pos_conv_kernel=128, - encoder_pos_conv_groups=16, - encoder_num_layers=24, - encoder_num_heads=16, - encoder_attention_dropout=0.0, - encoder_ff_interm_features=4096, - encoder_ff_interm_dropout=0.1, - encoder_dropout=0.0, - encoder_layer_norm_first=True, - encoder_layer_drop=0.1, - aux_num_out=31, -) - - -model.load_state_dict( - torch.hub.load_state_dict_from_url( - "https://dl.fbaipublicfiles.com/mms/torchaudio/ctc_alignment_mling_uroman/model.pt" - ) -) -model.eval() -model.to(device) - - -def get_emission(waveform): - with torch.inference_mode(): - # NOTE: this step is essential - waveform = torch.nn.functional.layer_norm(waveform, waveform.shape) - emission, _ = model(waveform) - return torch.log_softmax(emission, dim=-1) - - -# Construct the dictionary -# '@' represents the OOV token -# and are fairseq's legacy tokens, which're not used. -# token is omitted as we do not use it in this tutorial -dictionary = { - "": 0, - "": 1, - "": 2, - "@": 3, - "a": 4, - "i": 5, - "e": 6, - "n": 7, - "o": 8, - "u": 9, - "t": 10, - "s": 11, - "r": 12, - "m": 13, - "k": 14, - "l": 15, - "d": 16, - "g": 17, - "h": 18, - "y": 19, - "b": 20, - "p": 21, - "w": 22, - "c": 23, - "v": 24, - "j": 25, - "z": 26, - "f": 27, - "'": 28, - "q": 29, - "x": 30, -} - - -###################################################################### -# Before aligning the speech with transcripts, we need to make sure -# the transcripts are already romanized. Here are the BASH commands -# required for saving raw transcript to a file, downloading the uroman -# romanizer and using it to obtain romanized transcripts, and PyThon -# commands required for further normalizing the romanized transcript. +# The transcripts passed to the pipeline must be normalized beforehand. +# The exact process of normalization depends on language. +# +# Languages that do not have explicit word boundaries +# (such as Chinese, Japanese and Korean) require segmentation first. +# There are dedicated tools for this, but let's say we have segmented +# transcript. +# +# The first step of normalization is romanization. +# `uroman `__ is a tool that +# supports many languages. +# +# Here is a BASH commands to romanize the input text file and write +# the output to another text file using ``uroman``. # # .. code-block:: bash # -# Save the raw transcript to a file -# echo 'raw text' > text.txt -# git clone https://github.com/isi-nlp/uroman -# uroman/bin/uroman.pl < text.txt > text_romanized.txt +# $ echo "des événements d'actualité qui se sont produits durant l'année 1882" > text.txt +# $ uroman/bin/uroman.pl < text.txt > text_romanized.txt +# $ cat text_romanized.txt +# +# .. code-block:: text +# +# Cette page concerne des evenements d'actualite qui se sont produits durant l'annee 1882 +# +# The next step is to remove non-alphabets and punctuations. +# The following snippet normalizes the romanized transcript. 
# - -###################################################################### # .. code-block:: python # # import re +# +# # def normalize_uroman(text): # text = text.lower() # text = text.replace("’", "'") @@ -292,80 +203,99 @@ def get_emission(waveform): # text = re.sub(' +', ' ', text) # return text.strip() # -# file = "text_romanized.txt" -# f = open(file, "r") -# lines = f.readlines() -# text_normalized = normalize_uroman(lines[0].strip()) # - +# with open("text_romanized.txt", "r") as f: +# for line in f: +# text_normalized = normalize_uroman(line) +# print(text_normalized) +# +# Running the script on the above exanple produces the following. +# +# .. code-block:: text +# +# cette page concerne des evenements d'actualite qui se sont produits durant l'annee +# +# Note that, in this example, since "1882" was not romanized by ``uroman``, +# it was removed in the normalization step. +# To avoid this, one needs to romanize numbers, but this is known to be a non-trivial task. +# ###################################################################### +# Aligning transcripts to speech +# ------------------------------ +# +# Now we perform the forced alignment for multiple languages. +# +# # German # ~~~~~~ -speech_file = torchaudio.utils.download_asset("tutorial-assets/10349_8674_000087.flac", progress=False) - text_raw = "aber seit ich bei ihnen das brot hole" text_normalized = "aber seit ich bei ihnen das brot hole" -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) +url = "https://download.pytorch.org/torchaudio/tutorial-assets/10349_8674_000087.flac" +waveform, sample_rate = torchaudio.load( + url, frame_offset=int(0.5 * bundle.sample_rate), num_frames=int(2.5 * bundle.sample_rate) +) ###################################################################### # - -waveform, _ = torchaudio.load(speech_file, frame_offset=int(0.5 * SAMPLE_RATE), num_frames=int(2.5 * SAMPLE_RATE)) - -emission = get_emission(waveform.to(device)) -num_frames = emission.size(1) +assert sample_rate == bundle.sample_rate ###################################################################### # transcript = text_normalized.split() -word_spans = compute_alignments(emission, transcript, dictionary) +tokens = tokenizer(transcript) + +emission, token_spans = compute_alignments(waveform, transcript) +num_frames = emission.size(1) -plot_alignments(waveform, word_spans, emission, transcript) +plot_alignments(waveform, token_spans, emission, transcript) + +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) +IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, transcript[0]) +preview_word(waveform, token_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript[1]) +preview_word(waveform, token_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript[2]) +preview_word(waveform, token_spans[2], num_frames, transcript[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, transcript[3]) +preview_word(waveform, token_spans[3], num_frames, transcript[3]) ###################################################################### # -preview_word(waveform, 
word_spans[4], num_frames, transcript[4]) +preview_word(waveform, token_spans[4], num_frames, transcript[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, transcript[5]) +preview_word(waveform, token_spans[5], num_frames, transcript[5]) ###################################################################### # -preview_word(waveform, word_spans[6], num_frames, transcript[6]) +preview_word(waveform, token_spans[6], num_frames, transcript[6]) ###################################################################### # -preview_word(waveform, word_spans[7], num_frames, transcript[7]) +preview_word(waveform, token_spans[7], num_frames, transcript[7]) ###################################################################### # Chinese @@ -379,276 +309,277 @@ def get_emission(waveform): # However this is not needed if you only want character-level alignments. # -speech_file = torchaudio.utils.download_asset("tutorial-assets/mvdr/clean_speech.wav", progress=False) - text_raw = "关 服务 高端 产品 仍 处于 供不应求 的 局面" text_normalized = "guan fuwu gaoduan chanpin reng chuyu gongbuyingqiu de jumian" -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) - ###################################################################### # -waveform, _ = torchaudio.load(speech_file) +url = "https://download.pytorch.org/torchaudio/tutorial-assets/mvdr/clean_speech.wav" +waveform, sample_rate = torchaudio.load(url) waveform = waveform[0:1] -emission = get_emission(waveform.to(device)) -num_frames = emission.size(1) +###################################################################### +# +assert sample_rate == bundle.sample_rate ###################################################################### # transcript = text_normalized.split() -word_spans = compute_alignments(emission, transcript, dictionary) +emission, token_spans = compute_alignments(waveform, transcript) +num_frames = emission.size(1) + +plot_alignments(waveform, token_spans, emission, transcript) -plot_alignments(waveform, word_spans, emission, transcript) +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) +IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, transcript[0]) +preview_word(waveform, token_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript[1]) +preview_word(waveform, token_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript[2]) +preview_word(waveform, token_spans[2], num_frames, transcript[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, transcript[3]) +preview_word(waveform, token_spans[3], num_frames, transcript[3]) ###################################################################### # -preview_word(waveform, word_spans[4], num_frames, transcript[4]) +preview_word(waveform, token_spans[4], num_frames, transcript[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, transcript[5]) +preview_word(waveform, token_spans[5], num_frames, transcript[5]) 
###################################################################### # -preview_word(waveform, word_spans[6], num_frames, transcript[6]) +preview_word(waveform, token_spans[6], num_frames, transcript[6]) ###################################################################### # -preview_word(waveform, word_spans[7], num_frames, transcript[7]) +preview_word(waveform, token_spans[7], num_frames, transcript[7]) ###################################################################### # -preview_word(waveform, word_spans[8], num_frames, transcript[8]) +preview_word(waveform, token_spans[8], num_frames, transcript[8]) ###################################################################### # Polish # ~~~~~~ -speech_file = torchaudio.utils.download_asset("tutorial-assets/5090_1447_000088.flac", progress=False) - text_raw = "wtedy ujrzałem na jego brzuchu okrągłą czarną ranę" text_normalized = "wtedy ujrzalem na jego brzuchu okragla czarna rane" -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) +url = "https://download.pytorch.org/torchaudio/tutorial-assets/5090_1447_000088.flac" +waveform, sample_rate = torchaudio.load(url, num_frames=int(4.5 * bundle.sample_rate)) ###################################################################### # - -waveform, _ = torchaudio.load(speech_file, num_frames=int(4.5 * SAMPLE_RATE)) - -emission = get_emission(waveform.to(device)) -num_frames = emission.size(1) +assert sample_rate == bundle.sample_rate ###################################################################### # transcript = text_normalized.split() -word_spans = compute_alignments(emission, transcript, dictionary) +emission, token_spans = compute_alignments(waveform, transcript) +num_frames = emission.size(1) -plot_alignments(waveform, word_spans, emission, transcript) +plot_alignments(waveform, token_spans, emission, transcript) + +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) +IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, transcript[0]) +preview_word(waveform, token_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript[1]) +preview_word(waveform, token_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript[2]) +preview_word(waveform, token_spans[2], num_frames, transcript[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, transcript[3]) +preview_word(waveform, token_spans[3], num_frames, transcript[3]) ###################################################################### # -preview_word(waveform, word_spans[4], num_frames, transcript[4]) +preview_word(waveform, token_spans[4], num_frames, transcript[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, transcript[5]) +preview_word(waveform, token_spans[5], num_frames, transcript[5]) ###################################################################### # -preview_word(waveform, word_spans[6], num_frames, transcript[6]) +preview_word(waveform, token_spans[6], num_frames, transcript[6]) ###################################################################### # 
-preview_word(waveform, word_spans[7], num_frames, transcript[7]) +preview_word(waveform, token_spans[7], num_frames, transcript[7]) ###################################################################### # Portuguese # ~~~~~~~~~~ -speech_file = torchaudio.utils.download_asset("tutorial-assets/6566_5323_000027.flac", progress=False) - text_raw = "na imensa extensão onde se esconde o inconsciente imortal" text_normalized = "na imensa extensao onde se esconde o inconsciente imortal" -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) +url = "https://download.pytorch.org/torchaudio/tutorial-assets/6566_5323_000027.flac" +waveform, sample_rate = torchaudio.load( + url, frame_offset=int(bundle.sample_rate), num_frames=int(4.6 * bundle.sample_rate) +) ###################################################################### # - -waveform, _ = torchaudio.load(speech_file, frame_offset=int(SAMPLE_RATE), num_frames=int(4.6 * SAMPLE_RATE)) - -emission = get_emission(waveform.to(device)) -num_frames = emission.size(1) +assert sample_rate == bundle.sample_rate ###################################################################### # transcript = text_normalized.split() -word_spans = compute_alignments(emission, transcript, dictionary) +emission, token_spans = compute_alignments(waveform, transcript) +num_frames = emission.size(1) + +plot_alignments(waveform, token_spans, emission, transcript) -plot_alignments(waveform, word_spans, emission, transcript) +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) +IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, transcript[0]) +preview_word(waveform, token_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript[1]) +preview_word(waveform, token_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript[2]) +preview_word(waveform, token_spans[2], num_frames, transcript[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, transcript[3]) +preview_word(waveform, token_spans[3], num_frames, transcript[3]) ###################################################################### # -preview_word(waveform, word_spans[4], num_frames, transcript[4]) +preview_word(waveform, token_spans[4], num_frames, transcript[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, transcript[5]) +preview_word(waveform, token_spans[5], num_frames, transcript[5]) ###################################################################### # -preview_word(waveform, word_spans[6], num_frames, transcript[6]) +preview_word(waveform, token_spans[6], num_frames, transcript[6]) ###################################################################### # -preview_word(waveform, word_spans[7], num_frames, transcript[7]) +preview_word(waveform, token_spans[7], num_frames, transcript[7]) ###################################################################### # -preview_word(waveform, word_spans[8], num_frames, transcript[8]) +preview_word(waveform, token_spans[8], num_frames, transcript[8]) 
###################################################################### # Italian # ~~~~~~~ -speech_file = torchaudio.utils.download_asset("tutorial-assets/642_529_000025.flac", progress=False) - text_raw = "elle giacean per terra tutte quante" text_normalized = "elle giacean per terra tutte quante" -print("Raw Transcript: ", text_raw) -print("Normalized Transcript: ", text_normalized) +url = "https://download.pytorch.org/torchaudio/tutorial-assets/642_529_000025.flac" +waveform, sample_rate = torchaudio.load(url, num_frames=int(4 * bundle.sample_rate)) ###################################################################### # - -waveform, _ = torchaudio.load(speech_file, num_frames=int(4 * SAMPLE_RATE)) - -emission = get_emission(waveform.to(device)) -num_frames = emission.size(1) +assert sample_rate == bundle.sample_rate ###################################################################### # transcript = text_normalized.split() -word_spans = compute_alignments(emission, transcript, dictionary) +emission, token_spans = compute_alignments(waveform, transcript) +num_frames = emission.size(1) -plot_alignments(waveform, word_spans, emission, transcript) +plot_alignments(waveform, token_spans, emission, transcript) + +print("Raw Transcript: ", text_raw) +print("Normalized Transcript: ", text_normalized) +IPython.display.Audio(waveform, rate=sample_rate) ###################################################################### # -preview_word(waveform, word_spans[0], num_frames, transcript[0]) +preview_word(waveform, token_spans[0], num_frames, transcript[0]) ###################################################################### # -preview_word(waveform, word_spans[1], num_frames, transcript[1]) +preview_word(waveform, token_spans[1], num_frames, transcript[1]) ###################################################################### # -preview_word(waveform, word_spans[2], num_frames, transcript[2]) +preview_word(waveform, token_spans[2], num_frames, transcript[2]) ###################################################################### # -preview_word(waveform, word_spans[3], num_frames, transcript[3]) +preview_word(waveform, token_spans[3], num_frames, transcript[3]) ###################################################################### # -preview_word(waveform, word_spans[4], num_frames, transcript[4]) +preview_word(waveform, token_spans[4], num_frames, transcript[4]) ###################################################################### # -preview_word(waveform, word_spans[5], num_frames, transcript[5]) +preview_word(waveform, token_spans[5], num_frames, transcript[5]) ###################################################################### # Conclusion @@ -664,7 +595,6 @@ def get_emission(waveform): # --------------- # # Thanks to `Vineel Pratap `__ and `Zhaoheng -# Ni `__ for working on the forced aligner API, and -# `Moto Hira `__ for providing alignment merging and -# visualization utilities. +# Ni `__ for developing and open-sourcing the +# forced aligner API. 
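Across all of the language examples above, the returned ``TokenSpan`` objects index emission frames, so downstream uses (subtitles, segmentation) typically convert them to seconds. A small helper, sketched here against the same variables the tutorial uses (``waveform``, ``emission``, ``token_spans``, ``transcript``, ``bundle.sample_rate``), could look like:

```python
def spans_to_seconds(waveform, emission, token_spans, transcript, sample_rate):
    """Convert per-word lists of TokenSpan into (word, start_sec, end_sec, score) tuples."""
    ratio = waveform.size(1) / emission.size(1)  # audio samples per emission frame
    results = []
    for word, spans in zip(transcript, token_spans):
        start = spans[0].start * ratio / sample_rate
        end = spans[-1].end * ratio / sample_rate
        # Average score weighted by span length, as in the tutorial's _score helper.
        score = sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
        results.append((word, start, end, score))
    return results


# For example, for the Italian sample above:
# for word, start, end, score in spans_to_seconds(
#     waveform, emission, token_spans, transcript, bundle.sample_rate
# ):
#     print(f"{word}\t{start:.2f}s - {end:.2f}s\t{score:.2f}")
```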
# diff --git a/test/integration_tests/prototype/vggish_pipeline_test.py b/test/integration_tests/prototype/vggish_pipeline_test.py index 72c6e1e518..942cadebdb 100644 --- a/test/integration_tests/prototype/vggish_pipeline_test.py +++ b/test/integration_tests/prototype/vggish_pipeline_test.py @@ -1,16 +1,19 @@ +import unittest + import torchaudio from torchaudio.prototype.pipelines import VGGISH -def test_vggish(): - input_sr = VGGISH.sample_rate - input_proc = VGGISH.get_input_processor() - model = VGGISH.get_model() - path = torchaudio.utils.download_asset("test-assets/Chopin_Ballade_-1_In_G_Minor,_Op._23_excerpt.mp3") - waveform, sr = torchaudio.load(path, backend="ffmpeg") - waveform = waveform.mean(axis=0) - waveform = torchaudio.functional.resample(waveform, sr, input_sr) - batch = input_proc(waveform) - assert batch.shape == (62, 1, 96, 64) - output = model(batch) - assert output.shape == (62, 128) +class VGGishPipelineTest(unittest.TestCase): + def test_vggish(self): + input_sr = VGGISH.sample_rate + input_proc = VGGISH.get_input_processor() + model = VGGISH.get_model() + path = torchaudio.utils.download_asset("test-assets/Chopin_Ballade_-1_In_G_Minor,_Op._23_excerpt.mp3") + waveform, sr = torchaudio.load(path, backend="ffmpeg") + waveform = waveform.mean(axis=0) + waveform = torchaudio.functional.resample(waveform, sr, input_sr) + batch = input_proc(waveform) + assert batch.shape == (62, 1, 96, 64) + output = model(batch) + assert output.shape == (62, 128) diff --git a/torchaudio/pipelines/_wav2vec2/aligner.py b/torchaudio/pipelines/_wav2vec2/aligner.py index 2b90f3eca5..26fe68161d 100644 --- a/torchaudio/pipelines/_wav2vec2/aligner.py +++ b/torchaudio/pipelines/_wav2vec2/aligner.py @@ -32,12 +32,12 @@ def __call__(self, transcript: List[str]) -> List[List[int]]: return [[self.dictionary[c] for c in word] for word in transcript] -def _align_emission_and_tokens(emission: Tensor, tokens: List[int]): +def _align_emission_and_tokens(emission: Tensor, tokens: List[int], blank: int = 0): device = emission.device emission = emission.unsqueeze(0) targets = torch.tensor([tokens], dtype=torch.int32, device=device) - aligned_tokens, scores = F.forced_align(emission, targets, 0) + aligned_tokens, scores = F.forced_align(emission, targets, blank=blank) scores = scores.exp() # convert back to probability aligned_tokens, scores = aligned_tokens[0], scores[0] # remove batch dimension @@ -50,7 +50,7 @@ def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[Token """Generate list of time-stamped token sequences Args: - emission (Tensor): Sequence of token probability distributions. + emission (Tensor): Sequence of token probability distributions in log-domain. Shape: `(time, tokens)`. tokens (list of integer sequence): Tokenized transcript. Output from :py:class:`Wav2Vec2FABundle.Tokenizer`. @@ -75,11 +75,13 @@ def _flatten(nested_list): class Aligner(IAligner): + def __init__(self, blank): + self.blank = blank + def __call__(self, emission: Tensor, tokens: List[List[int]]) -> List[List[TokenSpan]]: if emission.ndim != 2: raise ValueError(f"The input emission must be 2D. 
Found: {emission.shape}")
-        emission = torch.log_softmax(emission, dim=-1)
-        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens))
+        aligned_tokens, scores = _align_emission_and_tokens(emission, _flatten(tokens), self.blank)
        spans = F.merge_tokens(aligned_tokens, scores)
        return _unflatten(spans, [len(ts) for ts in tokens])
diff --git a/torchaudio/pipelines/_wav2vec2/impl.py b/torchaudio/pipelines/_wav2vec2/impl.py
index 0fe663ed22..b7bdd7518f 100644
--- a/torchaudio/pipelines/_wav2vec2/impl.py
+++ b/torchaudio/pipelines/_wav2vec2/impl.py
@@ -1,4 +1,3 @@
-import copy
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

@@ -93,7 +92,7 @@ def get_model(self, *, dl_kwargs=None) -> Module:
        state_dict = self._get_state_dict(dl_kwargs)
        model.load_state_dict(state_dict)
        if self._normalize_waveform:
-            model = utils._apply_input_layer_norm(model)
+            model = utils._extend_model(model, normalize_waveform=True)
        model.eval()
        return model

@@ -1587,11 +1586,6 @@ def get_labels(self, star: Optional[str] = "*", blank: str = "-") -> Tuple[str,
        labels = super().get_labels(blank=blank)
        return labels if star is None else (*labels, star)

-    def _get_params_with_star(self):
-        params = copy.deepcopy(self._params)
-        params["aux_num_out"] += 1
-        return params
-
    def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:
        """Construct the model and load the pretrained weight.

@@ -1605,13 +1599,19 @@ def get_model(self, with_star: bool = True, *, dl_kwargs=None) -> Module:

        Returns:
            Variation of :py:class:`~torchaudio.models.Wav2Vec2Model`.
+
+        .. note::
+
+           The model created with this method returns probabilities in the log domain
+           (i.e. :py:func:`torch.nn.functional.log_softmax` is applied), whereas
+           the other Wav2Vec2 models return logits.
""" - params = self._get_params_with_star() if with_star else self._params - model = utils._get_model(self._model_type, params) - state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis, with_star) + model = utils._get_model(self._model_type, self._params) + state_dict = utils._get_state_dict(self._path, dl_kwargs, self._remove_aux_axis) model.load_state_dict(state_dict) - if self._normalize_waveform: - model = utils._apply_input_layer_norm(model) + model = utils._extend_model( + model, normalize_waveform=self._normalize_waveform, apply_log_softmax=True, append_star=with_star + ) model.eval() return model @@ -1650,7 +1650,7 @@ def get_aligner(self) -> Aligner: Returns: Aligner """ - return aligner.Aligner() + return aligner.Aligner(blank=0) MMS_FA = Wav2Vec2FABundle( diff --git a/torchaudio/pipelines/_wav2vec2/utils.py b/torchaudio/pipelines/_wav2vec2/utils.py index 69e869208b..e690e8103c 100644 --- a/torchaudio/pipelines/_wav2vec2/utils.py +++ b/torchaudio/pipelines/_wav2vec2/utils.py @@ -24,13 +24,23 @@ class _Wav2Vec2Model(nn.Module): This is used for layer normalization at the input """ - def __init__(self, model: Wav2Vec2Model): + def __init__(self, model: Wav2Vec2Model, normalize_waveform: bool, apply_log_softmax: bool, append_star: bool): super().__init__() self.model = model + self.normalize_waveform = normalize_waveform + self.apply_log_softmax = apply_log_softmax + self.append_star = append_star def forward(self, waveforms: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: - waveforms = nn.functional.layer_norm(waveforms, waveforms.shape) - return self.model(waveforms, lengths) + if self.normalize_waveform: + waveforms = nn.functional.layer_norm(waveforms, waveforms.shape) + output, output_lengths = self.model(waveforms, lengths) + if self.apply_log_softmax: + output = torch.nn.functional.log_softmax(output, dim=-1) + if self.append_star: + star_dim = torch.zeros((1, output.size(1), 1), dtype=output.dtype, device=output.device) + output = torch.cat((output, star_dim), dim=-1) + return output, output_lengths @torch.jit.export def extract_features( @@ -39,13 +49,14 @@ def extract_features( lengths: Optional[Tensor] = None, num_layers: Optional[int] = None, ) -> Tuple[List[Tensor], Optional[Tensor]]: - waveforms = nn.functional.layer_norm(waveforms, waveforms.shape) + if self.normalize_waveform: + waveforms = nn.functional.layer_norm(waveforms, waveforms.shape) return self.model.extract_features(waveforms, lengths, num_layers) -def _apply_input_layer_norm(module): - """Add extra layer_norm to the model""" - return _Wav2Vec2Model(module) +def _extend_model(module, normalize_waveform, apply_log_softmax=False, append_star=False): + """Add extra transformations to the model""" + return _Wav2Vec2Model(module, normalize_waveform, apply_log_softmax, append_star) def _remove_aux_axes(state_dict, axes): @@ -65,23 +76,13 @@ def _remove_aux_axes(state_dict, axes): state_dict[key] = torch.stack([mat[i] for i in range(mat.size(0)) if i not in axes]) -def _add_star_dim(state_dict): - w, b = state_dict["aux.weight"], state_dict["aux.bias"] - zeros = torch.zeros((1, w.size(1)), device=w.device, dtype=w.dtype) - state_dict["aux.weight"] = torch.cat((zeros, w), dim=0) - ones = torch.ones((1,), device=b.device, dtype=b.dtype) - state_dict["aux.bias"] = torch.cat((b, ones), dim=0) - - -def _get_state_dict(url, dl_kwargs, remove_axes=None, add_star=False): +def _get_state_dict(url, dl_kwargs, remove_axes=None): if not url.startswith("https"): url = 
f"https://download.pytorch.org/torchaudio/models/{url}" dl_kwargs = {} if dl_kwargs is None else dl_kwargs state_dict = load_state_dict_from_url(url, **dl_kwargs) if remove_axes: _remove_aux_axes(state_dict, remove_axes) - if add_star: - _add_star_dim(state_dict) return state_dict