diff --git a/faster_whisper/transcribe.py b/faster_whisper/transcribe.py
index f389f09f..52dac676 100644
--- a/faster_whisper/transcribe.py
+++ b/faster_whisper/transcribe.py
@@ -7,7 +7,7 @@
 from dataclasses import asdict, dataclass
 from inspect import signature
 from math import ceil
-from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
+from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Union
 from warnings import warn
 
 import ctranslate2
@@ -81,11 +81,11 @@ class TranscriptionOptions:
     compression_ratio_threshold: Optional[float]
     condition_on_previous_text: bool
     prompt_reset_on_temperature: float
-    temperatures: List[float]
+    temperatures: Union[List[float], Tuple[float, ...]]
     initial_prompt: Optional[Union[str, Iterable[int]]]
     prefix: Optional[str]
     suppress_blank: bool
-    suppress_tokens: Optional[List[int]]
+    suppress_tokens: Union[List[int], Tuple[int, ...]]
     without_timestamps: bool
     max_initial_timestamp: float
    word_timestamps: bool
@@ -106,7 +106,7 @@ class TranscriptionInfo:
     duration_after_vad: float
     all_language_probs: Optional[List[Tuple[str, float]]]
     transcription_options: TranscriptionOptions
-    vad_options: VadOptions
+    vad_options: Optional[VadOptions]
 
 
 class BatchedInferencePipeline:
@@ -121,7 +121,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
         encoder_output, outputs = self.generate_segment_batched(
             features, tokenizer, options
         )
-
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
@@ -130,8 +129,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
             segment_sizes.append(segment_size)
             (
                 subsegments,
-                seek,
-                single_timestamp_ending,
+                _,
+                _,
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
@@ -295,7 +294,7 @@ def transcribe(
         hallucination_silence_threshold: Optional[float] = None,
         batch_size: int = 8,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """transcribe audio in chunks in batched fashion and return with language info.
@@ -582,7 +581,7 @@ def __init__(
         num_workers: int = 1,
         download_root: Optional[str] = None,
         local_files_only: bool = False,
-        files: dict = None,
+        files: Optional[dict] = None,
         **model_kwargs,
     ):
         """Initializes the Whisper model.
@@ -731,7 +730,7 @@ def transcribe(
         clip_timestamps: Union[str, List[float]] = "0",
         hallucination_silence_threshold: Optional[float] = None,
         hotwords: Optional[str] = None,
-        language_detection_threshold: Optional[float] = 0.5,
+        language_detection_threshold: float = 0.5,
         language_detection_segments: int = 1,
     ) -> Tuple[Iterable[Segment], TranscriptionInfo]:
         """Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
             elif isinstance(vad_parameters, dict):
                 vad_parameters = VadOptions(**vad_parameters)
             speech_chunks = get_speech_timestamps(audio, vad_parameters)
-            audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+            audio_chunks, _ = collect_chunks(audio, speech_chunks)
             audio = np.concatenate(audio_chunks, axis=0)
             duration_after_vad = audio.shape[0] / sampling_rate
 
@@ -925,7 +924,7 @@ def transcribe(
             condition_on_previous_text=condition_on_previous_text,
             prompt_reset_on_temperature=prompt_reset_on_temperature,
             temperatures=(
-                temperature if isinstance(temperature, (list, tuple)) else [temperature]
+                temperature if isinstance(temperature, (List, Tuple)) else [temperature]
             ),
             initial_prompt=initial_prompt,
             prefix=prefix,
@@ -953,7 +952,8 @@ def transcribe(
 
         if speech_chunks:
             segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
-
+        if isinstance(vad_parameters, dict):
+            vad_parameters = VadOptions(**vad_parameters)
         info = TranscriptionInfo(
             language=language,
             language_probability=language_probability,
@@ -974,7 +974,7 @@ def _split_segments_by_timestamps(
         segment_size: int,
         segment_duration: float,
         seek: int,
-    ) -> List[List[int]]:
+    ) -> Tuple[List[Any], int, bool]:
         current_segments = []
         single_timestamp_ending = (
             len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
         )
@@ -1517,8 +1517,8 @@ def add_word_timestamps(
         num_frames: int,
         prepend_punctuations: str,
         append_punctuations: str,
-        last_speech_timestamp: float,
-    ) -> float:
+        last_speech_timestamp: Union[float, None],
+    ) -> Optional[float]:
         if len(segments) == 0:
             return
 
@@ -1665,9 +1665,11 @@ def find_alignment(
         text_indices = np.array([pair[0] for pair in alignments])
         time_indices = np.array([pair[1] for pair in alignments])
 
-        words, word_tokens = tokenizer.split_to_word_tokens(
-            text_token + [tokenizer.eot]
-        )
+        if isinstance(text_token, int):
+            tokens = [text_token] + [tokenizer.eot]
+        else:
+            tokens = text_token + [tokenizer.eot]
+        words, word_tokens = tokenizer.split_to_word_tokens(tokens)
         if len(word_tokens) <= 1:
             # return on eot only
             # >>> np.pad([], (1, 0))
@@ -1715,7 +1717,7 @@ def detect_language(
         audio: Optional[np.ndarray] = None,
         features: Optional[np.ndarray] = None,
         vad_filter: bool = False,
-        vad_parameters: Union[dict, VadOptions] = None,
+        vad_parameters: Optional[Union[dict, VadOptions]] = None,
         language_detection_segments: int = 1,
         language_detection_threshold: float = 0.5,
     ) -> Tuple[str, float, List[Tuple[str, float]]]:
@@ -1747,18 +1749,24 @@ def detect_language(
         if audio is not None:
             if vad_filter:
                 speech_chunks = get_speech_timestamps(audio, vad_parameters)
-                audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
+                audio_chunks, _ = collect_chunks(audio, speech_chunks)
                 audio = np.concatenate(audio_chunks, axis=0)
-
+            assert (
+                audio is not None
+            ), "Audio is None after concatenating the audio_chunks"
             audio = audio[
                 : language_detection_segments * self.feature_extractor.n_samples
             ]
             features = self.feature_extractor(audio)
-
+        assert (
+            features is not None
+        ), "No features extracted from the audio file"
         features = features[
             ..., : language_detection_segments * self.feature_extractor.nb_max_frames
         ]
-
+        assert (
+            features is not None
+        ), "No features extracted when detecting language in audio segments"
         detected_language_info = {}
         for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
             encoder_output = self.encode(
@@ -1828,13 +1836,13 @@ def get_compression_ratio(text: str) -> float:
 
 def get_suppressed_tokens(
     tokenizer: Tokenizer,
-    suppress_tokens: Tuple[int],
-) -> Optional[List[int]]:
-    if -1 in suppress_tokens:
+    suppress_tokens: Optional[List[int]],
+) -> Tuple[int, ...]:
+    if suppress_tokens is None or len(suppress_tokens) == 0:
+        suppress_tokens = []  # interpret empty string as an empty list
+    elif -1 in suppress_tokens:
         suppress_tokens = [t for t in suppress_tokens if t >= 0]
         suppress_tokens.extend(tokenizer.non_speech_tokens)
-    elif suppress_tokens is None or len(suppress_tokens) == 0:
-        suppress_tokens = []  # interpret empty string as an empty list
     else:
         assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
 
diff --git a/faster_whisper/vad.py b/faster_whisper/vad.py
index 3bcca221..e5ef1964 100644
--- a/faster_whisper/vad.py
+++ b/faster_whisper/vad.py
@@ -3,7 +3,7 @@
 import os
 
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -44,7 +44,7 @@ class VadOptions:
 
 def get_speech_timestamps(
     audio: np.ndarray,
-    vad_options: Optional[VadOptions] = None,
+    vad_options: Optional[Union[dict, VadOptions]] = None,
     sampling_rate: int = 16000,
     **kwargs,
 ) -> List[dict]:
@@ -61,7 +61,8 @@ def get_speech_timestamps(
     """
     if vad_options is None:
         vad_options = VadOptions(**kwargs)
-
+    if isinstance(vad_options, dict):
+        vad_options = VadOptions(**vad_options)
     onset = vad_options.onset
     min_speech_duration_ms = vad_options.min_speech_duration_ms
     max_speech_duration_s = vad_options.max_speech_duration_s
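Note (not part of the patch): a minimal usage sketch of the relaxed vad_parameters / vad_options handling above. It assumes faster-whisper with this patch applied; the file name "audio.wav" and the "tiny" model size are placeholders.

# get_speech_timestamps() now accepts either a VadOptions instance or a plain dict;
# a dict is converted internally via VadOptions(**vad_options).
from faster_whisper import WhisperModel
from faster_whisper.audio import decode_audio
from faster_whisper.vad import VadOptions, get_speech_timestamps

audio = decode_audio("audio.wav", sampling_rate=16000)
chunks = get_speech_timestamps(audio, {"min_speech_duration_ms": 250})
chunks = get_speech_timestamps(audio, VadOptions(min_speech_duration_ms=250))

# WhisperModel.transcribe() likewise normalizes a dict vad_parameters to VadOptions
# before building TranscriptionInfo, so info.vad_options matches its new
# Optional[VadOptions] annotation instead of leaking the raw dict.
model = WhisperModel("tiny")
segments, info = model.transcribe(
    "audio.wav",
    vad_filter=True,
    vad_parameters={"min_speech_duration_ms": 250},
)
print(type(info.vad_options))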