[w2vbert] support w2vbert fbank (wenet-e2e#2346)
* [ssl/w2vbert] support w2vbert fbank

* [ssl/w2vbert] add libsndfile in ut yaml
Mddct authored Feb 6, 2024
1 parent ad5b3b6 commit 44dad22
Showing 3 changed files with 67 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
@@ -40,7 +40,7 @@ jobs:
       run: |
         set -eux
         pip install -r requirements.txt
-        sudo apt update && sudo apt install -y ffmpeg libsox-dev
+        sudo apt update && sudo apt install -y ffmpeg libsox-dev libsndfile1
     - name: Run Pytest
       run: |
         set -eux
50 changes: 50 additions & 0 deletions test/wenet/ssl/w2vbert/test_w2vbert.py
@@ -0,0 +1,50 @@
from pathlib import Path
import pytest
import torch
import torchaudio

from wenet.dataset import processor

try:
    import fairseq2  # noqa
    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
    from fairseq2.memory import MemoryBlock
except ImportError:
    import os
    os.system('pip install --no-input fairseq2')
    import fairseq2  # noqa
    from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
    from fairseq2.memory import MemoryBlock


@pytest.mark.parametrize(
    "wav_file",
    [
        # "test/resources/aishell-BAC009S0724W0121.wav",
        "test/resources/librispeech-1995-1837-0001.wav",
    ])
def test_w2vbert_fbank(wav_file):
    fbank_convert = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=True,
    )
    audio_decoder = AudioDecoder(dtype=torch.float32)
    with Path(wav_file).open("rb") as fb:
        block = MemoryBlock(fb.read())
        decode_audio = audio_decoder(block)
    w2vbert_waveform = decode_audio['waveform']
    w2vbert_mat = fbank_convert(decode_audio)['fbank']

    wenet_waveform, _ = torchaudio.load(wav_file)
    fbank_args = {
        "num_mel_bins": 80,
        "frame_length": 25,
        "frame_shift": 10,
        "dither": 0.0,
    }
    sample = {'sample_rate': 16000, "wav": wenet_waveform, 'key': wav_file}
    wenet_mat = processor.compute_w2vbert_fbank(sample, **fbank_args)['feat']
    assert torch.allclose(w2vbert_waveform.transpose(0, 1), wenet_waveform)
    assert torch.allclose(w2vbert_mat, wenet_mat, atol=9e-5, rtol=9e-4)
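
The new test can be run locally with pytest test/wenet/ssl/w2vbert/test_w2vbert.py; it needs libsndfile1 (added to the CI image above) and fairseq2, which the try/except block pip installs if the import fails.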
16 changes: 16 additions & 0 deletions wenet/dataset/processor.py
@@ -222,6 +222,22 @@ def compute_fbank(sample,
    return sample


def compute_w2vbert_fbank(sample,
                          num_mel_bins=23,
                          frame_length=25,
                          frame_shift=10,
                          dither=0.0):
    """ Extract fbank features as used in w2vbert pretraining (4.5M hours).
    """
    sample = compute_fbank(sample, num_mel_bins, frame_length, frame_shift,
                           dither)
    mat = sample['feat']
    std, mean = torch.std_mean(mat, dim=0)
    mat = mat.subtract(mean).divide(std)
    sample['feat'] = mat
    return sample


def sort_by_feats(sample):
    assert 'feat' in sample
    assert isinstance(sample['feat'], torch.Tensor)
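
For reference, a minimal usage sketch of the new helper, mirroring the settings in the unit test above (the wav path and the 80-bin configuration come from the test; everything else is illustrative and not part of this commit):

import torchaudio

from wenet.dataset import processor

# Load a 16 kHz mono wav; the path is the one used by the unit test.
waveform, sample_rate = torchaudio.load("test/resources/librispeech-1995-1837-0001.wav")
sample = {"sample_rate": sample_rate, "wav": waveform, "key": "librispeech-1995-1837-0001"}

# Same fbank settings as the test: 80 mel bins, 25 ms window, 10 ms shift, no dither.
feat = processor.compute_w2vbert_fbank(sample,
                                       num_mel_bins=80,
                                       frame_length=25,
                                       frame_shift=10,
                                       dither=0.0)["feat"]

# feat has shape (num_frames, 80); unlike plain compute_fbank, each dimension
# is additionally standardized per utterance (zero mean, unit variance).
print(feat.shape, feat.mean(dim=0).abs().max().item())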
