From 44dad22cbe0e43b8ca552d5908957ce0a2a573b5 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Tue, 6 Feb 2024 18:20:40 +0800
Subject: [PATCH] [w2vbert] support w2vbert fbank (#2346)

* [ssl/w2vbert] support w2vbert fbank

* [ssl/w2vbert] add libsndfile in ut yaml
---
 .github/workflows/unit_test.yml        |  2 +-
 test/wenet/ssl/w2vbert/test_w2vbert.py | 50 ++++++++++++++++++++++++++
 wenet/dataset/processor.py             | 16 +++++++++
 3 files changed, 67 insertions(+), 1 deletion(-)
 create mode 100644 test/wenet/ssl/w2vbert/test_w2vbert.py

diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml
index caa8b199b..1f48d3210 100644
--- a/.github/workflows/unit_test.yml
+++ b/.github/workflows/unit_test.yml
@@ -40,7 +40,7 @@ jobs:
       run: |
         set -eux
         pip install -r requirements.txt
-        sudo apt update && sudo apt install -y ffmpeg libsox-dev
+        sudo apt update && sudo apt install -y ffmpeg libsox-dev libsndfile1
     - name: Run Pytest
       run: |
         set -eux
diff --git a/test/wenet/ssl/w2vbert/test_w2vbert.py b/test/wenet/ssl/w2vbert/test_w2vbert.py
new file mode 100644
index 000000000..e26052d23
--- /dev/null
+++ b/test/wenet/ssl/w2vbert/test_w2vbert.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+import pytest
+import torch
+import torchaudio
+
+from wenet.dataset import processor
+
+try:
+    import fairseq2  # noqa
+    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+    from fairseq2.memory import MemoryBlock
+except ImportError:
+    import os
+    os.system('pip install --no-input fairseq2')
+    import fairseq2  # noqa
+    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+    from fairseq2.memory import MemoryBlock
+
+
+@pytest.mark.parametrize(
+    "wav_file",
+    [
+        # "test/resources/aishell-BAC009S0724W0121.wav",
+        "test/resources/librispeech-1995-1837-0001.wav",
+    ])
+def test_w2vbert_fbank(wav_file):
+    fbank_convert = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=True,
+    )
+    audio_decoder = AudioDecoder(dtype=torch.float32)
+    with Path(wav_file).open("rb") as fb:
+        block = MemoryBlock(fb.read())
+    decode_audio = audio_decoder(block)
+    w2vbert_waveform = decode_audio['waveform']
+    w2vbert_mat = fbank_convert(decode_audio)['fbank']
+
+    wenet_waveform, _ = torchaudio.load(wav_file)
+    fbank_args = {
+        "num_mel_bins": 80,
+        "frame_length": 25,
+        "frame_shift": 10,
+        "dither": 0.0,
+    }
+    sample = {'sample_rate': 16000, "wav": wenet_waveform, 'key': wav_file}
+    wenet_mat = processor.compute_w2vbert_fbank(sample, **fbank_args)['feat']
+    assert torch.allclose(w2vbert_waveform.transpose(0, 1), wenet_waveform)
+    assert torch.allclose(w2vbert_mat, wenet_mat, atol=9e-5, rtol=9e-4)
diff --git a/wenet/dataset/processor.py b/wenet/dataset/processor.py
index e69165c9e..65c319479 100644
--- a/wenet/dataset/processor.py
+++ b/wenet/dataset/processor.py
@@ -222,6 +222,22 @@ def compute_fbank(sample,
     return sample
 
 
+def compute_w2vbert_fbank(sample,
+                          num_mel_bins=23,
+                          frame_length=25,
+                          frame_shift=10,
+                          dither=0.0):
+    """ Extract Pretrain w2vbert(4.5M hours) fbank
+    """
+    sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift,
+                           dither)
+    mat = sample['feat']
+    std, mean = torch.std_mean(mat, dim=0)
+    mat = mat.subtract(mean).divide(std)
+    sample['feat'] = mat
+    return sample
+
+
 def sort_by_feats(sample):
     assert 'feat' in sample
     assert isinstance(sample['feat'], torch.Tensor)