Skip to content

Commit

Permalink
Merge branch 'feat/cantonese'
Browse files Browse the repository at this point in the history
  • Loading branch information
Jemoka committed Oct 4, 2024
2 parents 7b31109 + bc0e130 commit 4320c05
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 20 deletions.
115 changes: 98 additions & 17 deletions batchalign/models/whisper/infer_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,89 @@ class WhisperASRModel(object):

def __init__(self, model, base="openai/whisper-large-v3", language="english", target_sample_rate=16000):
L.debug("Initializing whisper model...")
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=25,
stride_length_s=3,
device=DEVICE,
torch_dtype=torch.float32,
return_timestamps="word",
)
L.debug("Done, initalizing processor and config...")
self.__config = GenerationConfig.from_pretrained(base)
self.__config.no_repeat_ngram_size = 4

if language == "Cantonese":
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
# tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=30,
# stride_length_s=3,
device=DEVICE,
# torch_dtype=torch.float32,
return_timestamps="word",
)
self.__config = GenerationConfig.from_model_config(self.pipe.model.config)
self.__config.no_repeat_ngram_size = 4
self.__config.use_cache = False

forced_decoder_ids = self.pipe.tokenizer.get_decoder_prompt_ids(language="yue", task="transcribe")

suppress_tokens = []

# Define other parameters
return_attention_mask = False
pad_token_id = 50257
bos_token_id = 50257
eos_token_id = 50257
decoder_start_token_id = 50258
begin_suppress_tokens = [
220,
50257
],
alignment_heads = [
[5, 3],
[5, 9],
[8, 0],
[8, 4],
[8, 8],
[9, 0],
[9, 7],
[9, 9],
[10, 5]
]
lang_to_id = {"<|yue|>": 50325}
task_to_id = {"transcribe": 50359}
is_multilingual = True
max_initial_timestamp_index = 50
no_timestamps_token_id = 50363
prev_sot_token_id = 50361
max_length = 448

# Assign values to generation config
self.__config.forced_decoder_ids = forced_decoder_ids
self.__config.suppress_tokens = suppress_tokens
self.__config.pad_token_id = pad_token_id
self.__config.bos_token_id = bos_token_id
self.__config.eos_token_id = eos_token_id
self.__config.decoder_start_token_id = decoder_start_token_id
self.__config.lang_to_id = lang_to_id
self.__config.task_to_id = task_to_id
self.__config.alignment_heads = alignment_heads
self.__config.alignment_heads = alignment_heads
self.__config.begin_suppress_tokens = begin_suppress_tokens
self.__config.is_multilingual = is_multilingual
self.__config.max_initial_timestamp_index = max_initial_timestamp_index
self.__config.no_timestamps_token_id = no_timestamps_token_id
self.__config.prev_sot_token_id = prev_sot_token_id
self.__config.max_length =max_length

self.pipe.model.generation_config = self.__config

else:
self.pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=WhisperTokenizer.from_pretrained(base),
chunk_length_s=25,
stride_length_s=3,
device=DEVICE,
torch_dtype=torch.float32,
return_timestamps="word",
)
L.debug("Done, initalizing processor and config...")
processor = WhisperProcessor.from_pretrained(base)
L.debug("Whisper initialization done.")

Expand Down Expand Up @@ -147,14 +217,25 @@ def __call__(self, data, segments=None):
})

L.debug("Whisper Transcribing...")
config = {
"repetition_penalty": 1.001,
"generation_config": self.__config,
"task": "transcribe",
"language": self.lang
}

if self.lang == "Cantonese":
config = {
"repetition_penalty": 1.001,
# "generation_config": self.__config,
# "task": "transcribe",
# "language": self.lang
}

words = self.pipe(data.cpu().numpy(),
batch_size=1,
generate_kwargs = {
"repetition_penalty": 1.001,
"generation_config": self.__config,
"task": "transcribe",
"language": self.lang
})
generate_kwargs=config)

# "do_sample": True,
# "temperature": 0.1
# })
Expand Down
6 changes: 3 additions & 3 deletions batchalign/version
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
0.7.6-alpha.5
October 3rd, 2024
features for aux
0.7.6-alpha.6
October 3, 2024
default to whisper small for cantonese (try 2)

0 comments on commit 4320c05

Please sign in to comment.