Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lint #38

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open

Lint #38

Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 79 additions & 45 deletions transcribe_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,55 @@
import argparse
import io
import os
import speech_recognition as sr
import whisper
import torch

from datetime import datetime, timedelta
from queue import Queue
from sys import platform
from tempfile import NamedTemporaryFile
from time import sleep
from sys import platform

import speech_recognition as sr
import torch
import whisper


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", default="medium", help="Model to use",
choices=["tiny", "base", "small", "medium", "large"])
parser.add_argument("--non_english", action='store_true',
parser.add_argument(
"--model",
default="medium",
help="Model to use",
choices=["tiny", "base", "small", "medium", "large"],
)
parser.add_argument("--non_english",
action="store_true",
help="Don't use the english model.")
parser.add_argument("--energy_threshold", default=1000,
help="Energy level for mic to detect.", type=int)
parser.add_argument("--record_timeout", default=2,
help="How real time the recording is in seconds.", type=float)
parser.add_argument("--phrase_timeout", default=3,
help="How much empty space between recordings before we "
"consider it a new line in the transcription.", type=float)
if 'linux' in platform:
parser.add_argument("--default_microphone", default='pulse',
help="Default microphone name for SpeechRecognition. "
"Run this with 'list' to view available Microphones.", type=str)
parser.add_argument(
"--energy_threshold",
default=1000,
help="Energy level for mic to detect.",
type=int,
)
parser.add_argument(
"--record_timeout",
default=2,
help="How real time the recording is in seconds.",
type=float,
)
parser.add_argument(
"--phrase_timeout",
default=3,
help="How much empty space between recordings before we "
"consider it a new line in the transcription.",
type=float,
)
if "linux" in platform:
parser.add_argument(
"--default_microphone",
default="pulse",
help="Default microphone name for SpeechRecognition. "
"Run this with 'list' to view available Microphones.",
type=str,
)
args = parser.parse_args()

# The last time a recording was retrieved from the queue.
Expand All @@ -39,25 +60,32 @@ def main():
last_sample = bytes()
# Thread safe Queue for passing data from the threaded recording callback.
data_queue = Queue()
# We use SpeechRecognizer to record our audio because it has a nice feature where it can detect when speech ends.
# We use SpeechRecognizer to record our audio because it has a nice feature
# where it can detect when speech ends.
recorder = sr.Recognizer()
recorder.energy_threshold = args.energy_threshold
# Definitely do this, dynamic energy compensation lowers the energy threshold dramatically to a point where the SpeechRecognizer never stops recording.
# Definitely do this, dynamic energy compensation lowers the energy
# threshold dramatically to a point where the SpeechRecognizer never stops
# recording.
recorder.dynamic_energy_threshold = False

# Important for linux users.
# Prevents permanent application hang and crash by using the wrong Microphone
if 'linux' in platform:
# Prevents a permanent application hang or crash caused by using the wrong
# Microphone.
if "linux" in platform:
mic_name = args.default_microphone
if not mic_name or mic_name == 'list':
if not mic_name or mic_name == "list":
print("Available microphone devices are: ")
for index, name in enumerate(sr.Microphone.list_microphone_names()):
print(f"Microphone with name \"{name}\" found")
for index, name in enumerate(
sr.Microphone.list_microphone_names()):
print(f'Microphone with name "{name}" found')
return
else:
for index, name in enumerate(sr.Microphone.list_microphone_names()):
for index, name in enumerate(
sr.Microphone.list_microphone_names()):
if mic_name in name:
source = sr.Microphone(sample_rate=16000, device_index=index)
source = sr.Microphone(sample_rate=16000,
device_index=index)
break
else:
source = sr.Microphone(sample_rate=16000)
Expand All @@ -72,14 +100,14 @@ def main():
phrase_timeout = args.phrase_timeout

temp_file = NamedTemporaryFile().name
transcription = ['']
transcription = [""]

with source:
recorder.adjust_for_ambient_noise(source)

def record_callback(_, audio:sr.AudioData) -> None:
def record_callback(_, audio: sr.AudioData) -> None:
"""
Threaded callback function to receive audio data when recordings finish.
Threaded callback function. Receives audio data when recordings finish.
audio: An AudioData containing the recorded bytes.
"""
# Grab the raw bytes and push it into the thread safe queue.
Expand All @@ -88,7 +116,9 @@ def record_callback(_, audio:sr.AudioData) -> None:

# Create a background thread that will pass us raw audio bytes.
# We could do this manually but SpeechRecognizer provides a nice helper.
recorder.listen_in_background(source, record_callback, phrase_time_limit=record_timeout)
recorder.listen_in_background(source,
record_callback,
phrase_time_limit=record_timeout)

# Cue the user that we're ready to go.
print("Model loaded.\n")
Expand All @@ -99,44 +129,48 @@ def record_callback(_, audio:sr.AudioData) -> None:
# Pull raw recorded audio from the queue.
if not data_queue.empty():
phrase_complete = False
# If enough time has passed between recordings, consider the phrase complete.
# Clear the current working audio buffer to start over with the new data.
if phrase_time and now - phrase_time > timedelta(seconds=phrase_timeout):
# If enough time has passed between recordings, consider the
# phrase complete. Clear the current working audio buffer to
# start over with the new data.
if phrase_time and now - phrase_time > timedelta(
seconds=phrase_timeout):
last_sample = bytes()
phrase_complete = True
# This is the last time we received new audio data from the queue.
# The last time we received new audio data from the queue.
phrase_time = now

# Concatenate our current audio data with the latest audio data.
# Concatenate current audio data with the latest audio data.
while not data_queue.empty():
data = data_queue.get()
last_sample += data

# Use AudioData to convert the raw data to wav data.
audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE, source.SAMPLE_WIDTH)
audio_data = sr.AudioData(last_sample, source.SAMPLE_RATE,
source.SAMPLE_WIDTH)
wav_data = io.BytesIO(audio_data.get_wav_data())

# Write wav data to the temporary file as bytes.
with open(temp_file, 'w+b') as f:
with open(temp_file, "w+b") as f:
f.write(wav_data.read())

# Read the transcription.
result = audio_model.transcribe(temp_file, fp16=torch.cuda.is_available())
text = result['text'].strip()
result = audio_model.transcribe(temp_file,
fp16=torch.cuda.is_available())
text = result["text"].strip()

# If we detected a pause between recordings, add a new item to our transcription.
# Otherwise edit the existing one.
# If a pause was detected between recordings, add a new item to our
# transcription. Otherwise edit the existing one.
if phrase_complete:
transcription.append(text)
else:
transcription[-1] = text

# Clear the console to reprint the updated transcription.
os.system('cls' if os.name=='nt' else 'clear')
os.system("cls" if os.name == "nt" else "clear")
for line in transcription:
print(line)
# Flush stdout.
print('', end='', flush=True)
print("", end="", flush=True)

# Infinite loops are bad for processors, must sleep.
sleep(0.25)
Expand Down