diff --git a/transcribe_demo.py b/transcribe_demo.py
index 6dd8972..31b2e5b 100644
--- a/transcribe_demo.py
+++ b/transcribe_demo.py
@@ -4,8 +4,6 @@
 import io
 import os
 import speech_recognition as sr
-import whisper
-import torch
 
 from datetime import datetime, timedelta
 from queue import Queue
@@ -20,6 +18,8 @@ def main():
                         choices=["tiny", "base", "small", "medium", "large"])
     parser.add_argument("--non_english", action='store_true',
                         help="Don't use the english model.")
+    parser.add_argument("--use_openai_api", action='store_true',
+                        help="Use the OPENAI_API_KEY environment variable")
     parser.add_argument("--energy_threshold", default=1000,
                         help="Energy level for mic to detect.", type=int)
     parser.add_argument("--record_timeout", default=2,
@@ -62,16 +62,24 @@ def main():
     else:
         source = sr.Microphone(sample_rate=16000)
 
-    # Load / Download model
-    model = args.model
-    if args.model != "large" and not args.non_english:
-        model = model + ".en"
-    audio_model = whisper.load_model(model)
+    if args.use_openai_api:
+        import openai
+        # Load your API key from an environment variable or secret management service
+        openai.api_key = os.getenv('OPENAI_API_KEY')
+    else:
+
+        import whisper
+        import torch
+        # Load / Download model
+        model = args.model
+        if args.model != "large" and not args.non_english:
+            model = model + ".en"
+        audio_model = whisper.load_model(model)
 
     record_timeout = args.record_timeout
     phrase_timeout = args.phrase_timeout
 
-    temp_file = NamedTemporaryFile().name
+    temp_file = NamedTemporaryFile(suffix='.wav').name
     transcription = ['']
 
     with source:
@@ -120,8 +128,12 @@ def record_callback(_, audio:sr.AudioData) -> None:
                 with open(temp_file, 'w+b') as f:
                     f.write(wav_data.read())
 
-                # Read the transcription.
-                result = audio_model.transcribe(temp_file, fp16=torch.cuda.is_available())
+                if args.use_openai_api:
+                    with open(temp_file, 'rb') as f:
+                        result = openai.Audio.transcribe("whisper-1", f)
+                else:
+                    # Read the transcription.
+                    result = audio_model.transcribe(temp_file, fp16=torch.cuda.is_available())
                 text = result['text'].strip()
 
                 # If we detected a pause between recordings, add a new item to our transcripion.
@@ -132,7 +144,7 @@ def record_callback(_, audio:sr.AudioData) -> None:
                     transcription[-1] = text
 
                 # Clear the console to reprint the updated transcription.
-                os.system('cls' if os.name=='nt' else 'clear')
+                os.system('cls' if os.name == 'nt' else 'clear')
                 for line in transcription:
                     print(line)
                 # Flush stdout.
@@ -149,4 +161,4 @@ def record_callback(_, audio:sr.AudioData) -> None:
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()