From 27aa6cfcee1fac0e117a9b6e21ada8d999eaf9d7 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Thu, 31 Oct 2024 16:04:58 +0800
Subject: [PATCH 1/2] [cli] paraformer support batch infer

---
 wenet/cli/paraformer_model.py | 101 +++++++++++++++++++++-------------
 1 file changed, 64 insertions(+), 37 deletions(-)

diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py
index a4f834ab2..0e9f7f305 100644
--- a/wenet/cli/paraformer_model.py
+++ b/wenet/cli/paraformer_model.py
@@ -1,14 +1,14 @@
+import io
 import os
+from typing import Dict, List, Union
 
 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
-
 from wenet.cli.hub import Hub
 from wenet.paraformer.search import (gen_timestamps_from_peak,
                                      paraformer_greedy_search)
 from wenet.text.paraformer_tokenizer import ParaformerTokenizer
-from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu
 
 
 class Paraformer:
@@ -22,46 +22,73 @@ def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
         self.device = torch.device("cpu")
         self.tokenizer = ParaformerTokenizer(symbol_table=units_path)
 
-    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
-        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
-        waveform = waveform.to(torch.float).to(self.device)
-        if sample_rate != self.resample_rate:
-            waveform = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
-        feats = kaldi.fbank(waveform,
-                            num_mel_bins=80,
-                            frame_length=25,
-                            frame_shift=10,
-                            energy_floor=0.0,
-                            sample_frequency=self.resample_rate,
-                            window_type="hamming")
-        feats = feats.unsqueeze(0)
-        feats_lens = torch.tensor([feats.size(1)],
-                                  dtype=torch.int64,
-                                  device=feats.device)
+    @torch.inference_mode()
+    def transcribe_batch(self,
+                         audio_files: List[Union[str, bytes]],
+                         tokens_info: bool = False) -> List[Dict]:
+        feats_lst = []
+        feats_lens_lst = []
+        for audio in audio_files:
+            if isinstance(audio, bytes):
+                with io.BytesIO(audio) as fobj:
+                    waveform, sample_rate = torchaudio.load(fobj,
+                                                            normalize=False)
+            else:
+                waveform, sample_rate = torchaudio.load(audio, normalize=False)
+            if sample_rate != self.resample_rate:
+                waveform = torchaudio.transforms.Resample(
+                    orig_freq=sample_rate,
+                    new_freq=self.resample_rate)(waveform)
+
+            waveform = waveform.to(torch.float).to(self.device)
+            feats = kaldi.fbank(waveform,
+                                num_mel_bins=80,
+                                frame_length=25,
+                                frame_shift=10,
+                                energy_floor=0.0,
+                                sample_frequency=self.resample_rate,
+                                window_type="hamming")
+            feats_lst.append(feats)
+            feats_lens_lst.append(
+                torch.tensor(feats.shape[0], dtype=torch.int64))
+        feats_tensor = torch.nn.utils.rnn.pad_sequence(
+            feats_lst, batch_first=True).to(device=self.device)
+        feats_lens_tensor = torch.tensor(feats_lens_lst)
 
         decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
-            feats, feats_lens)
+            feats_tensor, feats_lens_tensor)
         cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
 
-        res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
-        result = {}
-        result['confidence'] = res.confidence
-        result['text'] = self.tokenizer.detokenize(res.tokens)[0]
-        if tokens_info:
-            tokens_info = []
-            times = gen_timestamps_from_peak(res.times,
-                                             num_frames=tp_alphas.size(1),
-                                             frame_rate=0.02)
-            for i, x in enumerate(res.tokens):
-                tokens_info.append({
-                    'token': self.tokenizer.char_dict[x],
-                    'start': round(times[i][0], 3),
-                    'end': round(times[i][1], 3),
-                    'confidence': round(res.tokens_confidence[i], 2)
-                })
-            result['tokens'] = tokens_info
+        results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)
+        r = []
+        for res in results:
+            result = {}
+            result['confidence'] = res.confidence
+            result['text'] = self.tokenizer.detokenize(res.tokens)[0]
+            if tokens_info:
+                tokens_info_l = []
+                times = gen_timestamps_from_peak(res.times,
+                                                 num_frames=tp_alphas.size(1),
+                                                 frame_rate=0.02)
+
+                for i, x in enumerate(res.tokens):
+                    tokens_info_l.append({
+                        'token':
+                        self.tokenizer.char_dict[x],
+                        'start':
+                        round(times[i][0], 3),
+                        'end':
+                        round(times[i][1], 3),
+                        'confidence':
+                        round(res.tokens_confidence[i], 2)
+                    })
+                result['tokens'] = tokens_info_l
+            r.append(result)
+        return r
+
+    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
+        result = self.transcribe_batch([audio_file], tokens_info)[0]
         return result
 
     def align(self, audio_file: str, label: str) -> dict:

From ce6c3d35e509dfea5f7b53dd3d7d2180e3f7cc15 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Thu, 31 Oct 2024 16:15:50 +0800
Subject: [PATCH 2/2] fix device

---
 wenet/cli/paraformer_model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py
index 0e9f7f305..20233255a 100644
--- a/wenet/cli/paraformer_model.py
+++ b/wenet/cli/paraformer_model.py
@@ -40,7 +40,7 @@ def transcribe_batch(self,
                     orig_freq=sample_rate,
                     new_freq=self.resample_rate)(waveform)
 
-            waveform = waveform.to(torch.float).to(self.device)
+            waveform = waveform.to(torch.float)
             feats = kaldi.fbank(waveform,
                                 num_mel_bins=80,
                                 frame_length=25,
@@ -53,7 +53,7 @@ def transcribe_batch(self,
                 torch.tensor(feats.shape[0], dtype=torch.int64))
         feats_tensor = torch.nn.utils.rnn.pad_sequence(
             feats_lst, batch_first=True).to(device=self.device)
-        feats_lens_tensor = torch.tensor(feats_lens_lst)
+        feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)
 
         decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
             feats_tensor, feats_lens_tensor)
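
Notes: for reference, a minimal usage sketch of the batch API added by this
series. The model directory and audio file names below are hypothetical; the
Paraformer constructor signature is the one shown in the hunk context above,
and the result-dict keys ('text', 'confidence', 'tokens') are those built in
transcribe_batch.

    from wenet.cli.paraformer_model import Paraformer

    model = Paraformer("path/to/paraformer_model_dir")

    # A batch may mix file paths and raw in-memory audio bytes.
    with open("utt2.wav", "rb") as fobj:
        audio_bytes = fobj.read()

    results = model.transcribe_batch(["utt1.wav", audio_bytes],
                                     tokens_info=True)
    for res in results:
        # Each result dict carries 'text' and 'confidence'; with
        # tokens_info=True it also carries per-token timestamps
        # under 'tokens'.
        print(res['text'], res['confidence'])

    # The single-utterance transcribe() is now a thin wrapper that
    # calls transcribe_batch() with a one-element list.
    single = model.transcribe("utt1.wav")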