-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
63 lines (48 loc) · 1.63 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
"""Train a Glow-TTS model with Coqui TTS on an LJSpeech-formatted dataset.

The dataset is read from ``data/bttm-out`` (``metadata.csv`` with the
``ljspeech`` formatter); the phoneme cache and all training output are
written under ``data``.
"""
import os

from trainer import Trainer, TrainerArgs
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# Root directory for the dataset, the phoneme cache and all training output.
output_path = "data"


def main():
    """Build the configs, load the train/eval samples and run training."""
    dataset_config = BaseDatasetConfig(
        meta_file_train="metadata.csv",
        path=os.path.join(output_path, "bttm-out"),
        formatter="ljspeech",
    )

    # GlowTTSConfig: all model related values for training, validating and testing.
    config = GlowTTSConfig(
        batch_size=32,
        eval_batch_size=16,
        num_loader_workers=8,
        num_eval_loader_workers=8,
        run_eval=True,
        test_delay_epochs=-1,
        epochs=250,
        text_cleaner="phoneme_cleaners",
        use_phonemes=True,
        phoneme_language="en-us",
        phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
        print_step=25,
        print_eval=False,
        mixed_precision=True,
        output_path=output_path,
        datasets=[dataset_config],
        save_step=500,
        # NOTE(review): this was lr=0.5, about 500x the GlowTTSConfig default
        # (1e-3, used there with RAdam + NoamLR warm-up). A value that large
        # is very likely to make training diverge, so the default is restored.
        lr=1e-3,
    )

    ap = AudioProcessor.init_from_config(config)

    # init_from_config returns the tokenizer together with a config object;
    # `config` is rebound so the model and trainer below receive the returned
    # one (presumably with the character/phoneme set filled in — verify
    # against the TTS version in use).
    tokenizer, config = TTSTokenizer.init_from_config(config)

    # Printed so the eval split fraction can be checked against dataset size.
    print(f"eval_split_size: {config.eval_split_size}")
    train_samples, eval_samples = load_tts_samples(
        dataset_config,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )

    # No speaker manager is passed.
    model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

    trainer = Trainer(
        TrainerArgs(),
        config,
        output_path,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()


# Guarded because the loaders use worker processes (num_loader_workers=8):
# with the "spawn" start method each worker re-imports the main module, and
# unguarded top-level code would start training again inside every worker.
if __name__ == "__main__":
    main()