-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinfer.py
51 lines (40 loc) · 1.85 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import os

import torch
import torchaudio
from nemo.collections.asr.models import EncDecCTCModel
from tqdm import tqdm

from utils import stt
# python infer.py --checkpoint ./checkpoints/exp0/last.ckpt --audio_dir 'I:/tts/3arabiyya/arabic-speech-corpus/test set/wav'
# python infer.py --checkpoint ./checkpoints/exp0/states_val_loss=19.79108.ckpt --audio_dir data/test_waves
def _list_wav_files(audio_dir):
    """Return the sorted paths of ``.wav`` files directly inside *audio_dir*.

    The extension check is case-insensitive and requires the dot, so
    ``clip.WAV`` is included while ``notes_wav`` is not. Sorting makes the
    output order deterministic (``os.scandir`` order is arbitrary).
    """
    return sorted(
        entry.path
        for entry in os.scandir(audio_dir)
        if entry.is_file() and entry.name.lower().endswith('.wav')
    )


def _load_audio(audio_fpath, device, target_sr=16000):
    """Load *audio_fpath* onto *device*, resample and peak-normalize it.

    Resamples to *target_sr* when needed, scales by the absolute peak across
    all channels, and returns the first channel only. Silent (all-zero)
    clips are returned unscaled, because dividing by a zero peak would fill
    the tensor with NaNs.
    """
    audio, sr = torchaudio.load(audio_fpath)
    audio = audio.to(device)
    if sr != target_sr:
        audio = torchaudio.functional.resample(audio, sr, target_sr)
    peak = audio.abs().max()
    if peak > 0:
        audio = audio / peak
    # Only the first channel is passed on to the model.
    return audio[0]


def infer(args):
    """Transcribe every ``.wav`` file in ``args.audio_dir`` with a CTC checkpoint.

    Expected attributes on *args*:
        checkpoint: path to the checkpoint passed to
            ``EncDecCTCModel.load_from_checkpoint``.
        hparams_file: hparams YAML passed alongside the checkpoint.
        audio_dir: directory scanned (non-recursively) for ``.wav`` files.
        output_file_path: text file that receives one
            ``"<file name>\\t<prediction>"`` line per audio file; its parent
            directory is created if missing.
        device: torch device string the model and audio are moved to.
    """
    model = EncDecCTCModel.load_from_checkpoint(
        args.checkpoint,
        hparams_file=args.hparams_file)
    model = model.to(args.device)
    model.eval()

    audio_fpaths = _list_wav_files(args.audio_dir)
    print(f"Found {len(audio_fpaths)} audio files @ {args.audio_dir}")

    # Create the output directory up front so open() does not fail on a fresh run.
    out_dir = os.path.dirname(args.output_file_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # eval() alone does not disable autograd; no_grad avoids building graphs
    # (and holding activations) for every file during pure inference.
    with open(args.output_file_path, 'w', encoding='utf-8') as out_file, torch.no_grad():
        for audio_fpath in tqdm(audio_fpaths):
            fname = os.path.basename(audio_fpath)
            audio = _load_audio(audio_fpath, args.device)
            text_pred = stt(model, audio)
            out_file.write(f"{fname}\t{text_pred}\n")
if __name__ == '__main__':
    # Command-line entry point: parse options and run inference.
    parser = argparse.ArgumentParser(
        description="Transcribe the WAV files in a directory with an "
                    "EncDecCTCModel checkpoint and write "
                    "'<file name>\\t<prediction>' lines to a text file.")
    parser.add_argument('--checkpoint', type=str,
                        default='./checkpoints/exp0/last.ckpt',
                        help='Checkpoint (.ckpt) to load the model from.')
    parser.add_argument('--hparams_file', type=str,
                        default='./logs/exp0/hparams.yaml',
                        help='hparams YAML used when loading the checkpoint.')
    parser.add_argument('--audio_dir', type=str,
                        default='./data/test_waves',
                        help='Directory containing the audio files to transcribe.')
    parser.add_argument('--output_file_path', type=str,
                        default='./data/infer_text.txt',
                        help='Where to write the transcriptions.')
    # Fall back to CPU so the script still runs on machines without CUDA;
    # on CUDA machines the default is unchanged ('cuda').
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Torch device (default: cuda if available, else cpu).')
    args = parser.parse_args()
    infer(args)