Skip to content

Commit

Permalink
improvements to algo
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Dec 20, 2024
1 parent 4e74b5c commit b2b777c
Showing 1 changed file with 175 additions and 49 deletions.
224 changes: 175 additions & 49 deletions tests/realtimestt_speechendpoint_binary_classified.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Debug switch: when True, the script prints per-update pause diagnostics
# (see the `if IS_DEBUG:` prints in process_queue and process_text).
#IS_DEBUG = True
IS_DEBUG = False
# When True, the script looks up a loopback input device with
# find_stereo_mix_index() and passes it to the recorder as 'input_device_index'.
USE_STEREO_MIX = False
# Case-insensitive substring that a device name must contain to be selected.
LOOPBACK_DEVICE_NAME = "stereomix"
# Value the device's 'hostApi' field must equal for the device to be selected.
LOOPBACK_DEVICE_HOST_API = 0

import os
import re
Expand All @@ -21,6 +24,34 @@
])

# When True, recorder_config['level'] is set to logging.DEBUG before the
# recorder is created.
EXTENDED_LOGGING = False
# Characters treated as sentence-ending punctuation; read by sentence_end()
# and strip_ending_punctuation().
sentence_end_marks = ['.', '!', '?', '。']


# Multiplier applied to the combined (model + whisper) pause in process_queue;
# smaller values shorten the silence required before speech is considered done.
detection_speed = 0.5  # set detection speed between 0.1 and 2.0

# Keep the configured value inside the supported [0.1, 2.0] range.
detection_speed = max(0.1, min(detection_speed, 2.0))

# Diagnostics shared between process_queue and process_text.
# last_detection_pause: recorder.post_speech_silence_duration captured at the
# start of process_text; shown in the live panel title.
last_detection_pause = 0
# NOTE: despite its name, process_queue stores the model-derived pause
# (new_detection) here, not the completion probability.
last_prob_complete = 0
# Most recent averaged whisper-punctuation pause computed in process_queue.
last_suggested_pause = 0
# Most recent final pause assigned to recorder.post_speech_silence_duration.
last_pause = 0
# NOTE(review): no use of the next two values is visible in this view — confirm
# they are still needed.
end_of_sentence_detection_pause = 0.3
maybe_end_of_sentence_detection_pause = 0.5
# Fallback pause when the text ends in no recognised punctuation; also the value
# the recorder is reset to at the start of process_text.
unknown_sentence_detection_pause = 0.8
# Pauses returned by get_suggested_whisper_pause() for text ending in
# "...", ".", "!" and "?" respectively.
ellipsis_pause = 1.7
punctuation_pause = 0.5
exclamation_pause = 0.4
question_pause = 0.3

# Length in seconds of the sliding window kept in text_time_deque: entries older
# than (time.time() - this value) are dropped in process_queue.
hard_break_even_on_background_noise = 3.0
# NOTE(review): the three thresholds below are not used in the visible part of
# the file; presumably they gate the hard break on repeated text — confirm.
hard_break_even_on_background_noise_min_texts = 3
hard_break_even_on_background_noise_min_chars = 15
hard_break_even_on_background_noise_min_similarity = 0.99

if __name__ == '__main__':

Expand Down Expand Up @@ -50,8 +81,6 @@

tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir)
classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir)
# tokenizer = DistilBertTokenizerFast.from_pretrained(model_dir, force_download=True)
# classification_model = DistilBertForSequenceClassification.from_pretrained(model_dir, force_download=True)
classification_model.to(device)
classification_model.eval()

Expand Down Expand Up @@ -88,13 +117,6 @@ def get_completion_probability(sentence, model, tokenizer, device, max_length):
(0.0, 1.0),
(1.0, 0)
]
# anchor_points = [
# (0.0, 0.4),
# (0.5, 0.3),
# (0.8, 0.2),
# (0.9, 0.1),
# (1.0, 0)
# ]

def interpolate_detection(prob):
# Clamp probability between 0.0 and 1.0 just in case
Expand Down Expand Up @@ -144,17 +166,8 @@ def is_speech_finished(text):
recorder = None
displayed_text = ""
text_time_deque = deque()

# Default values
end_of_sentence_detection_pause = 0.3
unknown_sentence_detection_pause = 0.8
mid_sentence_detection_pause = 1.7
hard_break_even_on_background_noise = 3.0
hard_break_even_on_background_noise_min_texts = 3
hard_break_even_on_background_noise_min_chars = 15
hard_break_even_on_background_noise_min_similarity = 0.99
texts_without_punctuation = []
relisten_on_abrupt_stop = True

abrupt_stop = False
prev_text = ""

Expand All @@ -170,6 +183,18 @@ def preprocess_text(text):
def text_detected(text):
text_queue.put(text)

def ends_with_string(text: str, s: str):
    """Return True if `text` ends with `s`, ignoring at most one trailing character.

    The one-character tolerance lets a suffix such as "..." still match when
    a single extra character (for example a closing quote) follows it.
    """
    # Only drop the last character when something remains to compare against.
    without_last = text[:-1] if len(text) > 1 else text
    return text.endswith(s) or without_last.endswith(s)

def sentence_end(text: str):
    """Return True if `text` is non-empty and its last character is in sentence_end_marks."""
    return bool(text) and text[-1] in sentence_end_marks

def additional_pause_based_on_words(text):
word_count = len(text.split())
pauses = {
Expand All @@ -182,10 +207,78 @@ def additional_pause_based_on_words(text):
6: 0.05,
}
return pauses.get(word_count, 0.0)

def strip_ending_punctuation(text, marks=None):
    """Return `text` without trailing whitespace and trailing sentence-end marks.

    All trailing marks are removed in a single pass, so mixed endings such as
    "?!" or "!?" are stripped completely regardless of the order of `marks`.
    (Stripping one mark at a time in list order left e.g. "Hello!" for "Hello!?".)

    Args:
        text: The text to clean.
        marks: Iterable of single-character marks to strip. Defaults to the
            module-level `sentence_end_marks`.

    Returns:
        The text with trailing whitespace removed first, then every trailing
        character that is one of `marks`.
    """
    if marks is None:
        marks = sentence_end_marks
    # str.rstrip(chars) treats its argument as a set of characters, so one call
    # removes any combination of the marks.
    return text.rstrip().rstrip("".join(marks))

def get_suggested_whisper_pause(text):
    """Map the punctuation at the end of `text` to a suggested pause.

    Returns the pause configured for the first matching ending, or
    `unknown_sentence_detection_pause` when the text ends in none of them.
    """
    # Order matters: "..." must be tested before ".", because text ending in an
    # ellipsis also ends in a period.
    pause_by_ending = (
        ("...", ellipsis_pause),
        (".", punctuation_pause),
        ("!", exclamation_pause),
        ("?", question_pause),
    )
    for ending, pause in pause_by_ending:
        if ends_with_string(text, ending):
            return pause
    return unknown_sentence_detection_pause

def find_stereo_mix_index():
    """Locate the loopback ("stereo mix") input device.

    Enumerates the PyAudio devices in index order and selects the first one
    whose name contains LOOPBACK_DEVICE_NAME (case-insensitive) and whose
    'hostApi' equals LOOPBACK_DEVICE_HOST_API.

    Returns:
        A tuple `(device_index, devices_info)`. `device_index` is the matching
        device's index, or None when no device matches. `devices_info` is a
        newline-terminated listing of every device inspected; on a match it
        ends with the matching device.
    """
    import pyaudio

    audio = pyaudio.PyAudio()
    try:
        wanted_name = LOOPBACK_DEVICE_NAME.lower()
        listing = []
        for i in range(audio.get_device_count()):
            dev = audio.get_device_info_by_index(i)
            listing.append(f"{dev['index']}: {dev['name']} (hostApi: {dev['hostApi']})\n")

            if (wanted_name in dev['name'].lower()
                    and dev['hostApi'] == LOOPBACK_DEVICE_HOST_API):
                return dev['index'], "".join(listing)

        return None, "".join(listing)
    finally:
        # Release the PortAudio resources held by this temporary instance on
        # every exit path; the instance was previously never terminated.
        audio.terminate()

def find_matching_texts(texts_without_punctuation):
    """Return the trailing run of entries that share the last entry's stripped text.

    Args:
        texts_without_punctuation: List of (original_text, stripped_text) tuples.

    Returns:
        The (original_text, stripped_text) tuples at the end of the list whose
        stripped text equals that of the last entry, in their original order.
        The run stops at the first non-matching entry when walking backwards.
        An empty input yields an empty list.
    """
    if not texts_without_punctuation:
        return []

    reference = texts_without_punctuation[-1][1]

    # Walk backwards to find where the run of matching entries begins.
    start = len(texts_without_punctuation)
    while start > 0:
        _, stripped = texts_without_punctuation[start - 1]
        if stripped != reference:
            break
        start -= 1

    return [
        (original, stripped)
        for original, stripped in texts_without_punctuation[start:]
    ]

def process_queue():
global recorder, full_sentences, prev_text, displayed_text, rich_text_stored, text_time_deque, abrupt_stop, rapid_sentence_end_detection

global recorder, full_sentences, prev_text, displayed_text, rich_text_stored, text_time_deque, abrupt_stop, rapid_sentence_end_detection, last_prob_complete, last_suggested_pause, last_pause
while True:
text = None # Initialize text to ensure it's defined

Expand Down Expand Up @@ -216,15 +309,30 @@ def process_queue():

text = preprocess_text(text)
current_time = time.time()

sentence_end_marks = ['.', '!', '?', '。']

if text.endswith("..."):
suggested_pause = mid_sentence_detection_pause
elif text and text[-1] in sentence_end_marks and prev_text and prev_text[-1] in sentence_end_marks:
suggested_pause = end_of_sentence_detection_pause
else:
suggested_pause = unknown_sentence_detection_pause
text_time_deque.append((current_time, text))

# get text without ending punctuation
text_without_punctuation = strip_ending_punctuation(text)

# print(f"Text: {text}, Text without punctuation: {text_without_punctuation}")
texts_without_punctuation.append((text, text_without_punctuation))

matches = find_matching_texts(texts_without_punctuation)
#print("Texts matching the last entry's stripped version:")

added_pauses = 0
contains_ellipses = False
for i, match in enumerate(matches):
same_text, stripped_punctuation = match
suggested_pause = get_suggested_whisper_pause(same_text)
added_pauses += suggested_pause
if ends_with_string(same_text, "..."):
contains_ellipses = True

avg_pause = added_pauses / len(matches) if len(matches) > 0 else 0
suggested_pause = avg_pause
# if contains_ellipses:
# suggested_pause += ellipsis_pause / 2

prev_text = text
import string
Expand All @@ -240,25 +348,23 @@ def process_queue():
# Interpolate rapid_sentence_end_detection based on prob_complete
new_detection = interpolate_detection(prob_complete)

pause = new_detection + suggested_pause
# pause = new_detection + suggested_pause
pause = (new_detection + suggested_pause) * detection_speed

# **Add Additional Pause Based on Word Count**
extra_pause = additional_pause_based_on_words(text)
pause += extra_pause # Add the extra pause to the total pause duration
# extra_pause = additional_pause_based_on_words(text)
# pause += extra_pause # Add the extra pause to the total pause duration

# Optionally, you can log this information for debugging
if IS_DEBUG:
print(f"Prob: {prob_complete:.2f}, "
f"whisper {suggested_pause:.2f}, "
f"model {new_detection:.2f}, "
f"extra {extra_pause:.2f}, "
# f"extra {extra_pause:.2f}, "
f"final {pause:.2f} | {transtext} ")

recorder.post_speech_silence_duration = pause

#if IS_DEBUG: print(f"Prob complete: {prob_complete:.2f}, pause whisper {suggested_pause:.2f}, model {new_detection:.2f}, final {pause:.2f} | {transtext} ")

text_time_deque.append((current_time, text))

# Remove old entries
while text_time_deque and text_time_deque[0][0] < current_time - hard_break_even_on_background_noise:
text_time_deque.popleft()
Expand All @@ -284,17 +390,22 @@ def process_queue():

new_displayed_text = rich_text.plain

if new_displayed_text != displayed_text:
displayed_text = new_displayed_text
panel = Panel(rich_text, title="[bold green]Live Transcription[/bold green]", border_style="bold green")
live.update(panel)
rich_text_stored = rich_text
displayed_text = new_displayed_text
last_prob_complete = new_detection
last_suggested_pause = suggested_pause
last_pause = pause
panel = Panel(rich_text, title=f"[bold green]Prob complete:[/bold green] [bold yellow]{prob_complete:.2f}[/bold yellow], pause whisper [bold yellow]{suggested_pause:.2f}[/bold yellow], model [bold yellow]{new_detection:.2f}[/bold yellow], last detection [bold yellow]{last_detection_pause:.2f}[/bold yellow]", border_style="bold green")
live.update(panel)
rich_text_stored = rich_text

text_queue.task_done()

def process_text(text):
global recorder, full_sentences, prev_text, abrupt_stop
if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}")
global recorder, full_sentences, prev_text, abrupt_stop, last_detection_pause
last_prob_complete, last_suggested_pause, last_pause
last_detection_pause = recorder.post_speech_silence_duration
if IS_DEBUG: print(f"Model pause: {last_prob_complete:.2f}, Whisper pause: {last_suggested_pause:.2f}, final pause: {last_pause:.2f}, last_detection_pause: {last_detection_pause:.2f}")
#if IS_DEBUG: print(f"SENTENCE: post_speech_silence_duration: {recorder.post_speech_silence_duration}")
recorder.post_speech_silence_duration = unknown_sentence_detection_pause
text = preprocess_text(text)
text = text.rstrip()
Expand All @@ -304,6 +415,7 @@ def process_text(text):

full_sentences.append(text)
prev_text = ""

text_detected("")

if abrupt_stop:
Expand All @@ -316,8 +428,9 @@ def process_text(text):

recorder_config = {
'spinner': False,
'model': 'large-v2',
'realtime_model_type': 'medium.en',
'model': 'large-v3',
#'realtime_model_type': 'medium.en',
'realtime_model_type': 'tiny.en',
'language': 'en',
'silero_sensitivity': 0.4,
'webrtc_sensitivity': 3,
Expand All @@ -327,10 +440,12 @@ def process_text(text):
'enable_realtime_transcription': True,
'realtime_processing_pause': 0.05,
'on_realtime_transcription_update': text_detected,
'silero_deactivity_detection': False,
'silero_deactivity_detection': True,
'early_transcription_on_silence': 0,
'beam_size': 5,
'beam_size_realtime': 3,
'beam_size_realtime': 1,
'batch_size': 4,
'realtime_batch_size': 4,
'no_log_file': True,
'initial_prompt_realtime': (
"End incomplete sentences with ellipses.\n"
Expand All @@ -345,6 +460,17 @@ def process_text(text):
if EXTENDED_LOGGING:
recorder_config['level'] = logging.DEBUG

if USE_STEREO_MIX:
device_index, devices_info = find_stereo_mix_index()
if device_index is None:
live.stop()
console.print("[bold red]Stereo Mix device not found. Available audio devices are:\n[/bold red]")
console.print(devices_info, style="red")
sys.exit(1)
else:
recorder_config['input_device_index'] = device_index
console.print(f"Using audio device index {device_index} for Stereo Mix.", style="green")

recorder = AudioToTextRecorder(**recorder_config)

initial_text = Panel(Text("Say something...", style="cyan bold"), title="[bold yellow]Waiting for Input[/bold yellow]", border_style="bold yellow")
Expand Down

0 comments on commit b2b777c

Please sign in to comment.