Skip to content

Commit

Permalink
fixing asr
Browse files (browse the repository at this point in the history)
  • Loading branch information
Jemoka committed Oct 22, 2024
1 parent 7547d34 commit 58170a8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 82 deletions.
2 changes: 1 addition & 1 deletion batchalign/cli/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _dispatch(command, lang, num_speakers,
for basedir, _, fs in os.walk(in_dir):
for f in fs:
path = Path(os.path.join(basedir, f))
ext = path.suffix.strip(".").strip()
ext = path.suffix.strip(".").strip().lower()

# calculate input path, convert if needed
inp_path = str(path)
Expand Down
89 changes: 10 additions & 79 deletions batchalign/models/whisper/infer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,85 +67,16 @@ def __init__(self, model, base="openai/whisper-large-v3", language="english", ta
self.__config = GenerationConfig.from_pretrained(base)
self.__config.no_repeat_ngram_size = 4

if language == "Cantonese":
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
# tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=30,
# stride_length_s=3,
device=DEVICE,
# torch_dtype=torch.float32,
return_timestamps="word",
)
self.__config = GenerationConfig.from_model_config(self.pipe.model.config)
self.__config.no_repeat_ngram_size = 4
self.__config.use_cache = False

forced_decoder_ids = self.pipe.tokenizer.get_decoder_prompt_ids(language="yue", task="transcribe")

suppress_tokens = []

# Define other parameters
return_attention_mask = False
pad_token_id = 50257
bos_token_id = 50257
eos_token_id = 50257
decoder_start_token_id = 50258
begin_suppress_tokens = [
220,
50257
],
alignment_heads = [
[5, 3],
[5, 9],
[8, 0],
[8, 4],
[8, 8],
[9, 0],
[9, 7],
[9, 9],
[10, 5]
]
lang_to_id = {"<|yue|>": 50325}
task_to_id = {"transcribe": 50359}
is_multilingual = True
max_initial_timestamp_index = 50
no_timestamps_token_id = 50363
prev_sot_token_id = 50361
max_length = 448

# Assign values to generation config
self.__config.forced_decoder_ids = forced_decoder_ids
self.__config.suppress_tokens = suppress_tokens
self.__config.pad_token_id = pad_token_id
self.__config.bos_token_id = bos_token_id
self.__config.eos_token_id = eos_token_id
self.__config.decoder_start_token_id = decoder_start_token_id
self.__config.lang_to_id = lang_to_id
self.__config.task_to_id = task_to_id
self.__config.alignment_heads = alignment_heads
self.__config.alignment_heads = alignment_heads
self.__config.begin_suppress_tokens = begin_suppress_tokens
self.__config.is_multilingual = is_multilingual
self.__config.max_initial_timestamp_index = max_initial_timestamp_index
self.__config.no_timestamps_token_id = no_timestamps_token_id
self.__config.prev_sot_token_id = prev_sot_token_id
self.__config.max_length =max_length

self.pipe.model.generation_config = self.__config

else:
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=25,
stride_length_s=3,
device=DEVICE,
torch_dtype=torch.float32,
return_timestamps="word",
)
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=25,
stride_length_s=3,
device=DEVICE,
torch_dtype=torch.float32,
return_timestamps="word",
)
L.debug("Done, initalizing processor and config...")
processor = WhisperProcessor.from_pretrained(base)
L.debug("Whisper initialization done.")
Expand Down
4 changes: 2 additions & 2 deletions batchalign/version
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
0.7.6-alpha.15
0.7.6-alpha.16
October 16, 2024
gerund support
fixing asr for file names

0 comments on commit 58170a8

Please sign in to comment.