diff --git a/transcribe_demo.py b/transcribe_demo.py
index 14ec910..ab11923 100644
--- a/transcribe_demo.py
+++ b/transcribe_demo.py
@@ -3,34 +3,57 @@
 import argparse
 import io
 import os
-import speech_recognition as sr
-import whisper
-import torch
-
 from datetime import datetime, timedelta
 from queue import Queue
+from sys import platform
 from tempfile import NamedTemporaryFile
 from time import sleep
-from sys import platform
+
+import speech_recognition as sr
+import torch
+import whisper
 
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model", default="medium", help="Model to use",
-                        choices=["tiny", "base", "small", "medium", "large"])
-    parser.add_argument("--non_english", action='store_true',
-                        help="Don't use the english model.")
-    parser.add_argument("--energy_threshold", default=1000,
-                        help="Energy level for mic to detect.", type=int)
-    parser.add_argument("--record_timeout", default=2,
-                        help="How real time the recording is in seconds.", type=float)
-    parser.add_argument("--phrase_timeout", default=3,
-                        help="How much empty space between recordings before we "
-                             "consider it a new line in the transcription.", type=float)
-    if 'linux' in platform:
-        parser.add_argument("--default_microphone", default='pulse',
-                            help="Default microphone name for SpeechRecognition. "
-                                 "Run this with 'list' to view available Microphones.", type=str)
+    parser.add_argument(
+        "--model",
+        default="medium",
+        help="Model to use",
+        choices=["tiny", "base", "small", "medium", "large"],
+    )
+    parser.add_argument(
+        "--non_english",
+        action="store_true",
+        help="Don't use the English model.",
+    )
+    parser.add_argument(
+        "--energy_threshold",
+        default=1000,
+        help="Energy level for the mic to detect speech.",
+        type=int,
+    )
+    parser.add_argument(
+        "--record_timeout",
+        default=2,
+        help="Max seconds of audio per chunk; lower is closer to real time.",
+        type=float,
+    )
+    parser.add_argument(
+        "--phrase_timeout",
+        default=3,
+        help="How much silence between recordings before we "
+        "consider it a new line in the transcription.",
+        type=float,
+    )
+    if "linux" in platform:
+        parser.add_argument(
+            "--default_microphone",
+            default="pulse",
+            help="Default microphone name for SpeechRecognition. "
+            "Run this with 'list' to view available microphones.",
+            type=str,
+        )
     args = parser.parse_args()
 
     # The last time a recording was retrieved from the queue.
@@ -39,25 +62,32 @@ def main():
     last_sample = bytes()
     # Thread safe Queue for passing data from the threaded recording callback.
     data_queue = Queue()
-    # We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
+    # We use SpeechRecognizer to record our audio because it has a nice feature
+    # where it can detect when speech ends.
     recorder = sr.Recognizer()
     recorder.energy_threshold = args.energy_threshold
-    # Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
+    # Definitely do this: dynamic energy compensation lowers the energy
+    # threshold dramatically, to the point where the SpeechRecognizer never
+    # stops recording.
     recorder.dynamic_energy_threshold = False
 
     # Important for linux users.
-    # Prevents permanent application hang and crash by using the wrong Microphone
-    if 'linux' in platform:
+    # Prevents a permanent application hang or crash caused by selecting
+    # the wrong microphone.
+    if "linux" in platform:
         mic_name = args.default_microphone
-        if not mic_name or mic_name == 'list':
+        if not mic_name or mic_name == "list":
             print("Available microphone devices are: ")
-            for index, name in enumerate(sr.Microphone.list_microphone_names()):
-                print(f"Microphone with name \"{name}\" found")
+            for index, name in enumerate(
+                    sr.Microphone.list_microphone_names()):
+                print(f'Microphone with name "{name}" found')
             return
         else:
-            for index, name in enumerate(sr.Microphone.list_microphone_names()):
+            for index, name in enumerate(
+                    sr.Microphone.list_microphone_names()):
                 if mic_name in name:
-                    source = sr.Microphone(sample_rate=16000, device_index=index)
+                    source = sr.Microphone(sample_rate=16000,
+                                           device_index=index)
                     break
     else:
         source = sr.Microphone(sample_rate=16000)
@@ -72,14 +102,14 @@ def main():
     phrase_timeout = args.phrase_timeout
 
     temp_file = NamedTemporaryFile().name
-    transcription = ['']
+    transcription = [""]
 
     with source:
         recorder.adjust_for_ambient_noise(source)
 
-    def record_callback(_, audio:sr.AudioData) -> None:
+    def record_callback(_, audio: sr.AudioData) -> None:
         """
-        Threaded callback function to receive audio data when recordings finish.
+        Threaded callback function. Receives audio data when recordings finish.
         audio: An AudioData containing the recorded bytes.
         """
         # Grab the raw bytes and push it into the thread safe queue.
@@ -88,7 +118,9 @@ def record_callback(_, audio:sr.AudioData) -> None:
 
     # Create a background thread that will pass us raw audio bytes.
     # We could do this manually but SpeechRecognizer provides a nice helper.
-    recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
+    recorder.listen_in_background(source,
+                                  record_callback,
+                                  phrase_time_limit=record_timeout)
 
     # Cue the user that we're ready to go.
     print("Model loaded.\n")
@@ -99,44 +131,48 @@ def record_callback(_, audio:sr.AudioData) -> None:
             # Pull raw recorded audio from the queue.
             if not data_queue.empty():
                 phrase_complete = False
-                # If enough time has passed between recordings, consider the phrase complete.
-                # Clear the current working audio buffer to start over with the new data.
-                if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
+                # If enough time has passed between recordings, consider the
+                # phrase complete. Clear the current working audio buffer to
+                # start over with the new data.
+                if phrase_time and now - phrase_time > timedelta(
+                        seconds=phrase_timeout):
                     last_sample = bytes()
                     phrase_complete = True
-                # This is the last time we received new audio data from the queue.
+                # The last time we received new audio data from the queue.
                 phrase_time = now
 
-                # Concatenate our current audio data with the latest audio data.
+                # Concatenate current audio data with the latest audio data.
                 while not data_queue.empty():
                     data = data_queue.get()
                     last_sample += data
 
                 # Use AudioData to convert the raw data to wav data.
-                audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
+                audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE,
+                                          source.SAMPLE_WIDTH)
                 wav_data = io.BytesIO(audio_data.get_wav_data())
 
                 # Write wav data to the temporary file as bytes.
-                with open(temp_file, 'w+b') as f:
+                with open(temp_file, "w+b") as f:
                     f.write(wav_data.read())
 
                 # Read the transcription.
-                result = audio_model.transcribe(temp_file, fp16=torch.cuda.is_available())
-                text = result['text'].strip()
+                result = audio_model.transcribe(temp_file,
+                                                fp16=torch.cuda.is_available())
+                text = result["text"].strip()
 
-                # If we detected a pause between recordings, add a new item to our transcription.
-                # Otherwise edit the existing one.
+                # If a pause was detected between recordings, add a new item
+                # to the transcription. Otherwise edit the existing one.
                 if phrase_complete:
                     transcription.append(text)
                 else:
                     transcription[-1] = text
 
                 # Clear the console to reprint the updated transcription.
-                os.system('cls' if os.name=='nt' else 'clear')
+                os.system("cls" if os.name == "nt" else "clear")
                 for line in transcription:
                     print(line)
                 # Flush stdout.
-                print('', end='', flush=True)
+                print("", end="", flush=True)
 
                 # Infinite loops are bad for processors, must sleep.
                 sleep(0.25)
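Note, not part of this patch: the temp_file round trip could be dropped in a follow-up, since openai-whisper's transcribe() also accepts an in-memory waveform as a float32 NumPy array. A minimal sketch, assuming the 16 kHz, 16-bit mono audio this script records; the transcribe_sample helper name is illustrative, not an API of either library:

    import numpy as np
    import torch
    import whisper

    audio_model = whisper.load_model("medium")

    def transcribe_sample(last_sample: bytes) -> str:
        # sr.AudioData carries signed 16-bit PCM; rescale to float32 in
        # [-1.0, 1.0], the in-memory format whisper's transcribe() expects
        # for a 16 kHz mono waveform.
        audio_np = np.frombuffer(last_sample, dtype=np.int16)
        audio_np = audio_np.astype(np.float32) / 32768.0
        result = audio_model.transcribe(audio_np,
                                        fp16=torch.cuda.is_available())
        return result["text"].strip()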
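Behavior is unchanged by this patch; for reviewers, an example invocation exercising the reformatted flags:

    python transcribe_demo.py --model small --record_timeout 2 --phrase_timeout 3

On Linux, running with --default_microphone list first prints the available device names.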