-
Notifications
You must be signed in to change notification settings - Fork 905
/
train.py
138 lines (118 loc) · 5.67 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import os
from time import sleep
import infolog
import tensorflow as tf
from hparams import hparams
from infolog import log
from tacotron.synthesize import tacotron_synthesize
from tacotron.train import tacotron_train
from wavenet_vocoder.train import wavenet_train
log = infolog.log  # Module-level alias for infolog.log; rebinds the same name imported via `from infolog import log`.
def save_seq(file, sequence, input_path):
    """Persist the Tacotron-2 pipeline stage flags and the WaveNet input path.

    Every element of ``sequence`` is written in its integer form (e.g. "0"/"1"),
    followed by ``input_path``; the fields are joined with '|' and overwrite
    ``file``. ``read_seq`` parses this format back so later runs can skip
    stages that already finished.
    """
    fields = list(map(str, map(int, sequence)))
    fields.append(input_path)
    with open(file, 'w') as handle:
        handle.write('|'.join(fields))
def read_seq(file):
    """Load the Tacotron-2 training state written by ``save_seq``.

    Args:
        file: path of the state file.

    Returns:
        (flags, input_path): ``flags`` is a list of booleans, one per stored
        stage flag, and ``input_path`` is the last '|'-separated field.
        If the file does not exist, or exists but is empty/blank, returns
        ``([False, False, False], '')`` so every stage is (re)run.
    """
    if not os.path.isfile(file):
        return [False, False, False], ''
    with open(file, 'r') as f:
        content = f.read()
    if not content.strip():
        # An empty file (e.g. left behind by an interrupted write, since
        # open(..., 'w') truncates before writing) would otherwise yield no
        # flags at all and break the caller's three-way unpacking.
        return [False, False, False], ''
    fields = content.split('|')
    return [bool(int(s)) for s in fields[:-1]], fields[-1]
def prepare_run(args):
    """Apply hparam overrides, set the TF C++ log level and initialise run logging.

    Args:
        args: parsed command-line arguments (reads ``hparams``, ``tf_log_level``,
            ``name``, ``model``, ``base_dir`` and ``slack_url``).

    Returns:
        (log_dir, modified_hp): the run's log directory
        ``<base_dir>/logs-<run name>`` (created if missing) and the value
        returned by ``hparams.parse(args.hparams)``.
    """
    modified_hp = hparams.parse(args.hparams)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = str(args.tf_log_level)

    # Fall back to the model name when no explicit run name was given.
    run_name = args.name if args.name else args.model
    log_dir = os.path.join(args.base_dir, 'logs-{}'.format(run_name))
    os.makedirs(log_dir, exist_ok=True)

    terminal_log = os.path.join(log_dir, 'Terminal_train_log')
    infolog.init(terminal_log, run_name, args.slack_url)
    return log_dir, modified_hp
def _log_stage_banner(title):
    """Log the '#'-framed banner announcing a Tacotron-2 pipeline stage."""
    log('\n#############################################################\n')
    log(title + '\n')
    log('###########################################################\n')


def train(args, log_dir, hparams):
    """Run the full Tacotron-2 pipeline, resuming from the saved stage state.

    Stages, each skipped when its flag in ``<log_dir>/state_log`` is set:
      1. Tacotron training (``tacotron_train``).
      2. Tacotron GTA synthesis (``tacotron_synthesize``), which returns the
         WaveNet input path.
      3. WaveNet training (``wavenet_train``) on that input path.
    After each stage completes, the flags and input path are written with
    ``save_seq`` so an interrupted run can resume where it stopped.

    Args:
        args: parsed command-line arguments (see ``main``).
        log_dir: run log directory returned by ``prepare_run``.
        hparams: hyperparameters passed through to every stage.

    Raises:
        RuntimeError: if a training stage returns no checkpoint, or if no
            usable WaveNet input path is available.
    """
    state_file = os.path.join(log_dir, 'state_log')
    # Get training states
    (taco_state, GTA_state, wave_state), input_path = read_seq(state_file)
    # Decide up front: by the end of this function every flag is set (or an
    # exception was raised), so checking afterwards would always be true.
    already_complete = bool(taco_state and GTA_state and wave_state)

    if not taco_state:
        _log_stage_banner('Tacotron Train')
        checkpoint = tacotron_train(args, log_dir, hparams)
        tf.reset_default_graph()
        # Sleep 1/2 second to let previous graph close and avoid error messages while synthesis
        sleep(0.5)
        if checkpoint is None:
            # Raising a bare string is a TypeError in Python 3; raise a real exception.
            raise RuntimeError('Error occured while training Tacotron, Exiting!')
        taco_state = 1
        save_seq(state_file, [taco_state, GTA_state, wave_state], input_path)
    else:
        checkpoint = os.path.join(log_dir, 'taco_pretrained/')

    if not GTA_state:
        _log_stage_banner('Tacotron GTA Synthesis')
        input_path = tacotron_synthesize(args, hparams, checkpoint)
        tf.reset_default_graph()
        # Sleep 1/2 second to let previous graph close and avoid error messages while Wavenet is training
        sleep(0.5)
        GTA_state = 1
        save_seq(state_file, [taco_state, GTA_state, wave_state], input_path)
    elif not input_path:
        # GTA synthesis finished in an earlier run: keep the path persisted in
        # the state file and only fall back to the default location if none
        # was saved.
        input_path = os.path.join('tacotron_' + args.output_dir, 'gta', 'map.txt')

    if not input_path:
        raise RuntimeError('input_path has an unpleasant value -> {}'.format(input_path))

    if not wave_state:
        _log_stage_banner('Wavenet Train')
        checkpoint = wavenet_train(args, log_dir, hparams, input_path)
        if checkpoint is None:
            raise RuntimeError('Error occured while training Wavenet, Exiting!')
        wave_state = 1
        save_seq(state_file, [taco_state, GTA_state, wave_state], input_path)

    if already_complete:
        log('TRAINING IS ALREADY COMPLETE!!')
def _str2bool(value):
    """Parse a command-line boolean such as 'True'/'False', 'yes'/'no', '1'/'0'.

    ``argparse`` with ``type=bool`` treats every non-empty string as True, so
    ``--restore False`` would still restore; this parser honours the text.
    An empty string maps to False, as ``bool('')`` did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def main():
    """Parse command-line arguments and train the selected model.

    ``--model`` selects 'Tacotron' (Tacotron only), 'WaveNet' (WaveNet only,
    on ``--wavenet_input``) or 'Tacotron-2' (the full resumable pipeline).

    Raises:
        ValueError: if ``--model`` is not one of the accepted models.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default='')
    parser.add_argument('--hparams', default='',
        help='Hyperparameter overrides as a comma-separated list of name=value pairs')
    parser.add_argument('--tacotron_input', default='training_data/train.txt')
    parser.add_argument('--wavenet_input', default='tacotron_output/gta/map.txt')
    parser.add_argument('--name', help='Name of logging directory.')
    parser.add_argument('--model', default='Tacotron-2')
    parser.add_argument('--input_dir', default='training_data', help='folder to contain inputs sentences/targets')
    parser.add_argument('--output_dir', default='output', help='folder to contain synthesized mel spectrograms')
    parser.add_argument('--mode', default='synthesis', help='mode for synthesis of tacotron after training')
    parser.add_argument('--GTA', default='True', help='Ground truth aligned synthesis, defaults to True, only considered in Tacotron synthesis mode')
    # type=bool would turn the string 'False' into True; parse the text instead.
    parser.add_argument('--restore', type=_str2bool, default=True, help='Set this to False to do a fresh training')
    parser.add_argument('--summary_interval', type=int, default=250,
        help='Steps between running summary ops')
    parser.add_argument('--embedding_interval', type=int, default=5000,
        help='Steps between updating embeddings projection visualization')
    parser.add_argument('--checkpoint_interval', type=int, default=2500,
        help='Steps between writing checkpoints')
    parser.add_argument('--eval_interval', type=int, default=5000,
        help='Steps between eval on test data')
    parser.add_argument('--tacotron_train_steps', type=int, default=100000, help='total number of tacotron training steps')
    parser.add_argument('--wavenet_train_steps', type=int, default=500000, help='total number of wavenet training steps')
    parser.add_argument('--tf_log_level', type=int, default=1, help='Tensorflow C++ log level.')
    parser.add_argument('--slack_url', default=None, help='slack webhook notification destination link')
    args = parser.parse_args()

    accepted_models = ['Tacotron', 'WaveNet', 'Tacotron-2']
    if args.model not in accepted_models:
        raise ValueError('please enter a valid model to train: {}'.format(accepted_models))

    log_dir, hparams = prepare_run(args)

    # args.model was validated above, so exactly one branch runs.
    if args.model == 'Tacotron':
        tacotron_train(args, log_dir, hparams)
    elif args.model == 'WaveNet':
        wavenet_train(args, log_dir, hparams, args.wavenet_input)
    elif args.model == 'Tacotron-2':
        train(args, log_dir, hparams)


if __name__ == '__main__':
    main()