forked from justinwlin/runpodWhisperx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
load_align_model.py
52 lines (47 loc) · 2.12 KB
/
load_align_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import sys
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# Language code (e.g. "en", "ja") whose alignment model should be loaded,
# taken from the first command-line argument. Fail with a usage message
# instead of a bare IndexError when the argument is missing.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} <language-code>")
lang = sys.argv[1]
# Alignment models bundled with torchaudio, keyed by language code.
# Each value is the name of an attribute of ``torchaudio.pipelines`` (a
# pipeline bundle providing ``get_model()`` / ``get_labels()``).
# Copied from whisperX v3.1.1:
# https://github.com/m-bain/whisperX/blob/v3.1.1/whisperx/alignment.py#L21
DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}
# Alignment models hosted on the Hugging Face Hub, keyed by language code.
# Each value is a repository id handed to ``from_pretrained`` for both the
# ``Wav2Vec2Processor`` and the ``Wav2Vec2ForCTC`` model.
DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "cs": "comodoro/wav2vec2-xls-r-300m-cs-250",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
    "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech",
    "he": "imvladikon/wav2vec2-xls-r-300m-hebrew",
    "vi": "nguyenvulebinh/wav2vec2-base-vi",
    "ko": "kresnik/wav2vec2-large-xlsr-korean",
    "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu",
    "te": "anuragshas/wav2vec2-large-xlsr-53-telugu",
    "hi": "theainerd/Wav2Vec2-large-xlsr-hindi",
}
# Load the alignment model for ``lang``: look it up in the torchaudio table
# first, then in the Hugging Face table; anything else is rejected.
# From https://github.com/m-bain/whisperX/issues/189#issuecomment-1523392800
# NOTE(review): nothing in this script uses the loaded objects afterwards;
# presumably it exists only to pre-populate the model caches (e.g. at image
# build time) — confirm before removing any of the "unused" loads below.
if lang in DEFAULT_ALIGN_MODELS_TORCH:
    model_name = DEFAULT_ALIGN_MODELS_TORCH[lang]
    # Public attribute lookup on the module instead of indexing its
    # ``__dict__``; also works if the module resolves names lazily.
    bundle = getattr(torchaudio.pipelines, model_name)
    align_model = bundle.get_model()
    labels = bundle.get_labels()  # output labels of the bundle's model
elif lang in DEFAULT_ALIGN_MODELS_HF:
    model_name = DEFAULT_ALIGN_MODELS_HF[lang]
    processor = Wav2Vec2Processor.from_pretrained(model_name)
    align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
else:
    # List the accepted codes so a typo on the command line is easy to fix.
    supported = sorted({*DEFAULT_ALIGN_MODELS_TORCH, *DEFAULT_ALIGN_MODELS_HF})
    raise ValueError(
        f"Unsupported language: {lang} "
        f"(supported: {', '.join(supported)})"
    )