From b2b777cbd2712e651318e97f0449ae9f2930803b Mon Sep 17 00:00:00 2001 From: KoljaB Date: Fri, 20 Dec 2024 11:42:52 +0100 Subject: [PATCH] improvements to algo --- ...imestt_speechendpoint_binary_classified.py | 224 ++++++++++++++---- 1 file changed, 175 insertions(+), 49 deletions(-) diff --git a/tests/realtimestt_speechendpoint_binary_classified.py b/tests/realtimestt_speechendpoint_binary_classified.py index 01fac23..c07d001 100644 --- a/tests/realtimestt_speechendpoint_binary_classified.py +++ b/tests/realtimestt_speechendpoint_binary_classified.py @@ -1,5 +1,8 @@ #IS_DEBUG = True IS_DEBUG = False +USE_STEREO_MIX = False +LOOPBACK_DEVICE_NAME = "stereomix" +LOOPBACK_DEVICE_HOST_API = 0 import os import re @@ -21,6 +24,34 @@ ]) EXTENDED_LOGGING = False +sentence_end_marks = ['.', '!', '?', '。'] + + +detection_speed = 0.5 # set detection speed between 0.1 and 2.0 + + + +if detection_speed < 0.1: + detection_speed = 0.1 +if detection_speed > 2.0: + detection_speed = 2.0 + +last_detection_pause = 0 +last_prob_complete = 0 +last_suggested_pause = 0 +last_pause = 0 +end_of_sentence_detection_pause = 0.3 +maybe_end_of_sentence_detection_pause = 0.5 +unknown_sentence_detection_pause = 0.8 +ellipsis_pause = 1.7 +punctuation_pause = 0.5 +exclamation_pause = 0.4 +question_pause = 0.3 + +hard_break_even_on_background_noise = 3.0 +hard_break_even_on_background_noise_min_texts = 3 +hard_break_even_on_background_noise_min_chars = 15 +hard_break_even_on_background_noise_min_similarity = 0.99 if __name__ == '__main__': @@ -50,8 +81,6 @@ tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir) classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir) - # tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir, force_download=True) - # classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir, force_download=True) classification_model.to(device) classification_model.eval() @@ -88,13 +117,6 @@ def get_completion_probability(sentence, model, tokenizer, device, max_length): (0.0, 1.0), (1.0, 0) ] - # anchor_points = [ - # (0.0, 0.4), - # (0.5, 0.3), - # (0.8, 0.2), - # (0.9, 0.1), - # (1.0, 0) - # ] def interpolate_detection(prob): # Clamp probability between 0.0 and 1.0 just in case @@ -144,17 +166,8 @@ def is_speech_finished(text): recorder = None displayed_text = "" text_time_deque = deque() - - # Default values - end_of_sentence_detection_pause = 0.3 - unknown_sentence_detection_pause = 0.8 - mid_sentence_detection_pause = 1.7 - hard_break_even_on_background_noise = 3.0 - hard_break_even_on_background_noise_min_texts = 3 - hard_break_even_on_background_noise_min_chars = 15 - hard_break_even_on_background_noise_min_similarity = 0.99 + texts_without_punctuation = [] relisten_on_abrupt_stop = True - abrupt_stop = False prev_text = "" @@ -170,6 +183,18 @@ def preprocess_text(text): def text_detected(text): text_queue.put(text) + def ends_with_string(text: str, s: str): + if text.endswith(s): + return True + if len(text) > 1 and text[:-1].endswith(s): + return True + return False + + def sentence_end(text: str): + if text and text[-1] in sentence_end_marks: + return True + return False + def additional_pause_based_on_words(text): word_count = len(text.split()) pauses = { @@ -182,10 +207,78 @@ def additional_pause_based_on_words(text): 6: 0.05, } return pauses.get(word_count, 0.0) + + def strip_ending_punctuation(text): + """Remove trailing periods and ellipses from text.""" + text = text.rstrip() + for char in sentence_end_marks: + text = text.rstrip(char) + return text + + def get_suggested_whisper_pause(text): + if ends_with_string(text, "..."): + return ellipsis_pause + elif ends_with_string(text, "."): + return punctuation_pause + elif ends_with_string(text, "!"): + return exclamation_pause + elif ends_with_string(text, "?"): + return question_pause + else: + return unknown_sentence_detection_pause + + def find_stereo_mix_index(): + import pyaudio + audio = pyaudio.PyAudio() + devices_info = "" + for i in range(audio.get_device_count()): + dev = audio.get_device_info_by_index(i) + devices_info += f"{dev['index']}: {dev['name']} (hostApi: {dev['hostApi']})\n" + + if (LOOPBACK_DEVICE_NAME.lower() in dev['name'].lower() + and dev['hostApi'] == LOOPBACK_DEVICE_HOST_API): + return dev['index'], devices_info + + return None, devices_info + + def find_matching_texts(texts_without_punctuation): + """ + Find entries where text_without_punctuation matches the last entry, + going backwards until the first non-match is found. + + Args: + texts_without_punctuation: List of tuples (original_text, stripped_text) + + Returns: + List of tuples (original_text, stripped_text) matching the last entry's stripped text, + stopping at the first non-match + """ + if not texts_without_punctuation: + return [] + + # Get the stripped text from the last entry + last_stripped_text = texts_without_punctuation[-1][1] + + matching_entries = [] + + # Iterate through the list backwards + for entry in reversed(texts_without_punctuation): + original_text, stripped_text = entry + + # If we find a non-match, stop + if stripped_text != last_stripped_text: + break + + # Add the matching entry to our results + matching_entries.append((original_text, stripped_text)) + + # Reverse the results to maintain original order + matching_entries.reverse() + + return matching_entries def process_queue(): - global recorder, full_sentences, prev_text, displayed_text, rich_text_stored, text_time_deque, abrupt_stop, rapid_sentence_end_detection - + global recorder, full_sentences, prev_text, displayed_text, rich_text_stored, text_time_deque, abrupt_stop, rapid_sentence_end_detection, last_prob_complete, last_suggested_pause, last_pause while True: text = None # Initialize text to ensure it's defined @@ -216,15 +309,30 @@ def process_queue(): text = preprocess_text(text) current_time = time.time() - - sentence_end_marks = ['.', '!', '?', '。'] - - if text.endswith("..."): - suggested_pause = mid_sentence_detection_pause - elif text and text[-1] in sentence_end_marks and prev_text and prev_text[-1] in sentence_end_marks: - suggested_pause = end_of_sentence_detection_pause - else: - suggested_pause = unknown_sentence_detection_pause + text_time_deque.append((current_time, text)) + + # get text without ending punctuation + text_without_punctuation = strip_ending_punctuation(text) + + # print(f"Text: {text}, Text without punctuation: {text_without_punctuation}") + texts_without_punctuation.append((text, text_without_punctuation)) + + matches = find_matching_texts(texts_without_punctuation) + #print("Texts matching the last entry's stripped version:") + + added_pauses = 0 + contains_ellipses = False + for i, match in enumerate(matches): + same_text, stripped_punctuation = match + suggested_pause = get_suggested_whisper_pause(same_text) + added_pauses += suggested_pause + if ends_with_string(same_text, "..."): + contains_ellipses = True + + avg_pause = added_pauses / len(matches) if len(matches) > 0 else 0 + suggested_pause = avg_pause + # if contains_ellipses: + # suggested_pause += ellipsis_pause / 2 prev_text = text import string @@ -240,25 +348,23 @@ def process_queue(): # Interpolate rapid_sentence_end_detection based on prob_complete new_detection = interpolate_detection(prob_complete) - pause = new_detection + suggested_pause + # pause = new_detection + suggested_pause + pause = (new_detection + suggested_pause) * detection_speed + # **Add Additional Pause Based on Word Count** - extra_pause = additional_pause_based_on_words(text) - pause += extra_pause # Add the extra pause to the total pause duration + # extra_pause = additional_pause_based_on_words(text) + # pause += extra_pause # Add the extra pause to the total pause duration # Optionally, you can log this information for debugging if IS_DEBUG: print(f"Prob: {prob_complete:.2f}, " f"whisper {suggested_pause:.2f}, " f"model {new_detection:.2f}, " - f"extra {extra_pause:.2f}, " + # f"extra {extra_pause:.2f}, " f"final {pause:.2f} | {transtext} ") recorder.post_speech_silence_duration = pause - #if IS_DEBUG: print(f"Prob complete: {prob_complete:.2f}, pause whisper {suggested_pause:.2f}, model {new_detection:.2f}, final {pause:.2f} | {transtext} ") - - text_time_deque.append((current_time, text)) - # Remove old entries while text_time_deque and text_time_deque[0][0] < current_time - hard_break_even_on_background_noise: text_time_deque.popleft() @@ -284,17 +390,22 @@ def process_queue(): new_displayed_text = rich_text.plain - if new_displayed_text != displayed_text: - displayed_text = new_displayed_text - panel = Panel(rich_text, title="[bold green]Live Transcription[/bold green]", border_style="bold green") - live.update(panel) - rich_text_stored = rich_text + displayed_text = new_displayed_text + last_prob_complete = new_detection + last_suggested_pause = suggested_pause + last_pause = pause + panel = Panel(rich_text, title=f"[bold green]Prob complete:[/bold green] [bold yellow]{prob_complete:.2f}[/bold yellow], pause whisper [bold yellow]{suggested_pause:.2f}[/bold yellow], model [bold yellow]{new_detection:.2f}[/bold yellow], last detection [bold yellow]{last_detection_pause:.2f}[/bold yellow]", border_style="bold green") + live.update(panel) + rich_text_stored = rich_text text_queue.task_done() def process_text(text): - global recorder, full_sentences, prev_text, abrupt_stop - if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}") + global recorder, full_sentences, prev_text, abrupt_stop, last_detection_pause + last_prob_complete, last_suggested_pause, last_pause + last_detection_pause = recorder.post_speech_silence_duration + if IS_DEBUG: print(f"Model pause: {last_prob_complete:.2f}, Whisper pause: {last_suggested_pause:.2f}, final pause: {last_pause:.2f}, last_detection_pause: {last_detection_pause:.2f}") + #if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}") recorder.post_speech_silence_duration = unknown_sentence_detection_pause text = preprocess_text(text) text = text.rstrip() @@ -304,6 +415,7 @@ def process_text(text): full_sentences.append(text) prev_text = "" + text_detected("") if abrupt_stop: @@ -316,8 +428,9 @@ def process_text(text): recorder_config = { 'spinner': False, - 'model': 'large-v2', - 'realtime_model_type': 'medium.en', + 'model': 'large-v3', + #'realtime_model_type': 'medium.en', + 'realtime_model_type': 'tiny.en', 'language': 'en', 'silero_sensitivity': 0.4, 'webrtc_sensitivity': 3, @@ -327,10 +440,12 @@ def process_text(text): 'enable_realtime_transcription': True, 'realtime_processing_pause': 0.05, 'on_realtime_transcription_update': text_detected, - 'silero_deactivity_detection': False, + 'silero_deactivity_detection': True, 'early_transcription_on_silence': 0, 'beam_size': 5, - 'beam_size_realtime': 3, + 'beam_size_realtime': 1, + 'batch_size': 4, + 'realtime_batch_size': 4, 'no_log_file': True, 'initial_prompt_realtime': ( "End incomplete sentences with ellipses.\n" @@ -345,6 +460,17 @@ def process_text(text): if EXTENDED_LOGGING: recorder_config['level'] = logging.DEBUG + if USE_STEREO_MIX: + device_index, devices_info = find_stereo_mix_index() + if device_index is None: + live.stop() + console.print("[bold red]Stereo Mix device not found. Available audio devices are:\n[/bold red]") + console.print(devices_info, style="red") + sys.exit(1) + else: + recorder_config['input_device_index'] = device_index + console.print(f"Using audio device index {device_index} for Stereo Mix.", style="green") + recorder = AudioToTextRecorder(**recorder_config) initial_text = Panel(Text("Say something...", style="cyan bold"), title="[bold yellow]Waiting for Input[/bold yellow]", border_style="bold yellow")