-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtts_demo.py
145 lines (124 loc) · 4.94 KB
/
tts_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from typing import Optional, Tuple, List
from models import build_model, generate_speech, list_available_voices
from tqdm.auto import tqdm
import soundfile as sf
from pathlib import Path
import numpy as np
# --- Constants -------------------------------------------------------------
# Sample rate (Hz) passed to soundfile when the generated audio is written.
SAMPLE_RATE: int = 24_000
# Model checkpoint handed to build_model() at startup.
DEFAULT_MODEL_PATH: str = 'kokoro-v1_0.pth'
# Destination for synthesized audio; rewritten on every generation.
DEFAULT_OUTPUT_FILE: str = 'output.wav'
# Language code: 'a' for American English, 'b' for British English.
DEFAULT_LANGUAGE: str = 'a'
# Text synthesized when the user just presses Enter at the text prompt.
DEFAULT_TEXT: str = "Hello, welcome to this text-to-speech test."

# A tqdm monitor_interval of 0 disables tqdm's background monitor thread;
# the original author set this for better Windows console support.
tqdm.monitor_interval = 0
def print_menu():
    """Display the top-level menu and return the user's choice.

    Returns:
        The raw line the user typed, with surrounding whitespace removed.
        No validation is done here; the caller decides what is valid.
    """
    menu_lines = (
        "\n=== Kokoro TTS Menu ===",
        "1. List available voices",
        "2. Generate speech",
        "3. Exit",
    )
    for line in menu_lines:
        print(line)
    selection = input("Select an option (1-3): ")
    return selection.strip()
def select_voice(voices: List[str], default: str = "af_bella") -> str:
    """Interactively ask the user to pick one of *voices*.

    Prints a numbered list of *voices*, then keeps prompting until the user
    enters a valid 1-based index or presses Enter to accept the default.

    Args:
        voices: Voice names to choose from, in display order.
        default: Voice returned when the user just presses Enter. If it is
            not among *voices* (and *voices* is non-empty), the first listed
            voice is used instead, so the returned name always refers to a
            voice that is actually available.

    Returns:
        The selected voice name.
    """
    # Never offer a default that isn't on disk; the caller turns the name
    # into a "voices/<name>.pt" path and generation would fail otherwise.
    if voices and default not in voices:
        default = voices[0]

    print("\nAvailable voices:")
    for number, voice in enumerate(voices, 1):
        print(f"{number}. {voice}")

    while True:
        raw = input(
            f"\nSelect a voice number (or press Enter for default '{default}'): "
        ).strip()
        if not raw:
            return default
        try:
            number = int(raw)
        except ValueError:
            print("Please enter a valid number.")
            continue
        if 1 <= number <= len(voices):
            return voices[number - 1]
        print("Invalid choice. Please try again.")
def get_text_input() -> str:
    """Prompt for the text to synthesize.

    Returns:
        The user's input with surrounding whitespace removed, or
        ``DEFAULT_TEXT`` when the input is blank.
    """
    for line in (
        "\nEnter the text you want to convert to speech",
        "(or press Enter for default text)",
    ):
        print(line)
    entered = input("> ").strip()
    if entered:
        return entered
    return DEFAULT_TEXT
def get_speed() -> float:
    """Prompt until the user supplies a speech speed in [0.5, 2.0].

    Returns:
        The chosen speed multiplier; an empty answer selects 1.0.
    """
    while True:
        answer = input("\nEnter speech speed (0.5-2.0, default 1.0): ").strip()
        if not answer:
            return 1.0
        try:
            value = float(answer)
        except ValueError:
            print("Please enter a valid number.")
            continue
        if not 0.5 <= value <= 2.0:
            print("Speed must be between 0.5 and 2.0")
            continue
        return value
def _show_voices() -> None:
    """Print every voice name returned by ``list_available_voices()``."""
    voices = list_available_voices()
    print("\nAvailable voices:")
    for voice in voices:
        print(f"- {voice}")


def _to_cpu_tensor(audio) -> torch.Tensor:
    """Return one generated audio chunk as a detached CPU tensor.

    NumPy arrays are converted to float32 tensors first. Detaching and moving
    to CPU lets the chunks be concatenated and later converted with
    ``.numpy()``, which raises for CUDA tensors or tensors requiring grad.
    """
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio).float()
    return audio.detach().cpu()


def _save_audio(chunks: List[torch.Tensor]) -> None:
    """Concatenate *chunks* along dim 0 and write them to DEFAULT_OUTPUT_FILE."""
    # NOTE(review): assumes every chunk is a 1-D run of samples so that
    # dim-0 concatenation joins them end to end — confirm against the model.
    final_audio = torch.cat(chunks, dim=0)
    output_path = Path(DEFAULT_OUTPUT_FILE)
    sf.write(str(output_path), final_audio.numpy(), SAMPLE_RATE)
    print(f"\nAudio saved to {output_path.absolute()}")


def _run_generation(model) -> None:
    """Collect voice, text and speed from the user, synthesize, and save."""
    voices = list_available_voices()
    if not voices:
        print("No voices found! Please check the voices directory.")
        return

    voice = select_voice(voices)
    text = get_text_input()
    speed = get_speed()
    print(f"\nGenerating speech for: '{text}'")
    print(f"Using voice: {voice}")
    print(f"Speed: {speed}x")

    all_audio = []
    # The model is called as a generator of (graphemes, phonemes, audio)
    # triples, one per text segment split on runs of newlines.
    generator = model(text, voice=f"voices/{voice}.pt", speed=speed, split_pattern=r'\n+')
    with tqdm(desc="Generating speech") as pbar:
        for gs, ps, audio in generator:
            if audio is not None:
                all_audio.append(_to_cpu_tensor(audio))
                print(f"\nGenerated segment: {gs}")
                print(f"Phonemes: {ps}")
                pbar.update(1)

    if all_audio:
        _save_audio(all_audio)
    else:
        print("Error: Failed to generate audio")


def main() -> None:
    """Run the interactive Kokoro TTS demo until the user chooses Exit.

    Builds the model once (on CUDA when available), then loops over the
    menu. A failure inside a single "Generate speech" action is reported and
    the user is returned to the menu instead of ending the session. EOF on
    stdin or Ctrl+C exits cleanly. The model reference is always dropped and
    the CUDA cache emptied on the way out.
    """
    import traceback

    model = None
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {device}")

        print("\nInitializing model...")
        with tqdm(total=1, desc="Building model") as pbar:
            model = build_model(DEFAULT_MODEL_PATH, device)
            pbar.update(1)

        while True:
            choice = print_menu()
            if choice == "1":
                _show_voices()
            elif choice == "2":
                try:
                    _run_generation(model)
                except EOFError:
                    # Closed stdin means no more input; let the outer handler exit.
                    raise
                except Exception as e:
                    print(f"Error generating speech: {e}")
                    traceback.print_exc()
            elif choice == "3":
                print("\nGoodbye!")
                break
            else:
                print("\nInvalid choice. Please try again.")
    except (EOFError, KeyboardInterrupt):
        print("\nGoodbye!")
    except Exception as e:
        print(f"Error in main: {e}")
        traceback.print_exc()
    finally:
        # Drop the model reference so its GPU memory can be released below.
        model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
if __name__ == '__main__':
    # Start the interactive menu only when run as a script, not on import.
    main()