diff --git a/__pycache__/live_dataloader.cpython-36.pyc b/__pycache__/live_dataloader.cpython-36.pyc
index f87c44e..bdb078d 100644
Binary files a/__pycache__/live_dataloader.cpython-36.pyc and b/__pycache__/live_dataloader.cpython-36.pyc differ
diff --git a/checkpoints/03/config.json b/checkpoints/03/config.json
new file mode 100644
index 0000000..f48c629
--- /dev/null
+++ b/checkpoints/03/config.json
@@ -0,0 +1,13 @@
+{
+"exp_name": "03",
+"model_type": "base",
+"data_root": "/home/mimbres/Documents/Dataset/LJSpeech-1.1",
+"max_epoch": 1000,
+"batch_train": 16,
+"batch_test": 5,
+"load": null,
+"save_interval": 50,
+"sel_display": 9,
+"silence_state_guide": -1,
+"generate": null
+}
\ No newline at end of file
diff --git a/checkpoints/03/hist.csv b/checkpoints/03/hist.csv
new file mode 100644
index 0000000..bb6aab3
--- /dev/null
+++ b/checkpoints/03/hist.csv
@@ -0,0 +1,9 @@
+Total,L1,BCE,Att
+0.31811372981593455,0.05411363613011655,0.26173089517951736,0.0022691986378944426
+0.27744487956221237,0.03581270973941568,0.23952356732346183,0.0021086022405865695
+0.27569847373431616,0.03505160883160676,0.238923375145389,0.00172349025419745
+0.2739839101666165,0.03463885126930983,0.23855921627923876,0.0007858422662799905
+0.2728093972138666,0.034298791065537154,0.23826296823637633,0.0002476381717573651
+0.27221924900495387,0.034004817459031236,0.2380408748232865,0.00017355680674400984
+0.27168442572834756,0.0337242923704461,0.2378063600001741,0.00015377319525789262
+0.27124152447098143,0.03350268867723687,0.23759463926156363,0.00014419614141819866
diff --git a/checkpoints/03/images/Sample 9: epoch = 0_att.png b/checkpoints/03/images/Sample 9: epoch = 0_att.png
new file mode 100644
index 0000000..0fc480f
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 0_att.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 0_mspec.png b/checkpoints/03/images/Sample 9: epoch = 0_mspec.png
new file mode 100644
index 0000000..9dc75d2
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 0_mspec.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 1_att.png b/checkpoints/03/images/Sample 9: epoch = 1_att.png
new file mode 100644
index 0000000..5316170
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 1_att.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 1_mspec.png b/checkpoints/03/images/Sample 9: epoch = 1_mspec.png
new file mode 100644
index 0000000..a3b5839
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 1_mspec.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 3_att.png b/checkpoints/03/images/Sample 9: epoch = 3_att.png
new file mode 100644
index 0000000..8197e40
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 3_att.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 3_mspec.png b/checkpoints/03/images/Sample 9: epoch = 3_mspec.png
new file mode 100644
index 0000000..40dc201
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 3_mspec.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 5_att.png b/checkpoints/03/images/Sample 9: epoch = 5_att.png
new file mode 100644
index 0000000..324d0ff
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 5_att.png differ
diff --git a/checkpoints/03/images/Sample 9: epoch = 5_mspec.png b/checkpoints/03/images/Sample 9: epoch = 5_mspec.png
new file mode 100644
index 0000000..b1ef218
Binary files /dev/null and b/checkpoints/03/images/Sample 9: epoch = 5_mspec.png differ
diff --git a/checkpoints/03/images/att_guide.png b/checkpoints/03/images/att_guide.png
new file mode 100644
index 0000000..8aeeed8
Binary files /dev/null and b/checkpoints/03/images/att_guide.png differ
diff --git a/config_template.json b/config_template.json
index 976f5a9..5e6e02a 100644
--- a/config_template.json
+++ b/config_template.json
@@ -1,8 +1,9 @@
 {
 "exp_name": "00",
 "model_type": "base",
+"data_root": "/home/mimbres/Documents/Dataset/LJSpeech-1.1",
 "max_epoch": 1000,
-"batch_train": 32,
+"batch_train": 16,
 "batch_test": 5,
 "load": null,
 "save_interval": 50,
diff --git a/live_dataloader.py b/live_dataloader.py
index d9a956e..21f291b 100644
--- a/live_dataloader.py
+++ b/live_dataloader.py
@@ -15,7 +15,7 @@
 from torch.utils.data.dataset import Dataset
 from nnmnkwii.datasets import FileDataSource, FileSourceDataset, MemoryCacheDataset
 
-DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'
+#DATA_ROOT = '/mnt/ssd2/data/LJSpeech-1.1'
 N_TRAIN = 13000 # N_TEST = 100 (13000~13099)
 X_SPEC_MAX = 193.99077 # Required for feature normalization
 X_MELSPEC_MAX = 0.04071
@@ -84,7 +84,7 @@ class LJSpeechDataset(Dataset):
     mode = ['melspec'] : return index, text, melspec
     mode = ['SSRN'] : return index, spec, melspec
     '''
-    def __init__(self, data_root_dir=DATA_ROOT, train_mode=False , output_mode='melspec', transform=None, data_sel=None):
+    def __init__(self, data_root_dir=None, train_mode=False , output_mode='melspec', transform=None, data_sel=None):
 
         self.wav_root_dir = data_root_dir + '/wavs/'
diff --git a/model/FastTacotron.py b/model/FastTacotron.py
index f917dfc..524404b 100644
--- a/model/FastTacotron.py
+++ b/model/FastTacotron.py
@@ -225,9 +225,10 @@ def forward(self, x_text, x_melspec, forced_att=None):
 
 
         # Decoding Mel-spectrogram, Y
         Y = self.audio_dec(RQ) # Bx80xT with T=T_audio
+        Y_sig = F.sigmoid(Y)
 
         if self.optional_output is True:
-            return Y, A, K, V, Q
+            return Y, Y_sig, A, K, V, Q
         else:
-            return Y, A
+            return Y, Y_sig, A
\ No newline at end of file
diff --git a/model/__pycache__/FastTacotron.cpython-36.pyc b/model/__pycache__/FastTacotron.cpython-36.pyc
index abdc3ad..f2de5bd 100644
Binary files a/model/__pycache__/FastTacotron.cpython-36.pyc and b/model/__pycache__/FastTacotron.cpython-36.pyc differ
diff --git a/test_text2mel.py b/test_text2mel.py
index 04a40f9..bf607d9 100644
--- a/test_text2mel.py
+++ b/test_text2mel.py
@@ -20,12 +20,14 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import matplotlib.pyplot as plt
 from torch.utils.data import DataLoader
 from torch.autograd import Variable
 from live_dataloader import LJSpeechDataset
 from util.save_load_config import save_config, load_config
 from tqdm import tqdm
+
 
 DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'
@@ -69,9 +71,8 @@ def generate_text2mel(model=nn.Module,
                       x_text=torch.autograd.Variable,
                       args=argparse.Namespace,
-                      max_output_len=int,
-                      save_to_file=True):
-
+                      max_output_len=int):
+
     model.eval()
     torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.
@@ -85,10 +86,6 @@ def generate_text2mel(model=nn.Module,
         if (torch.sum(out_y[0,:,-5:]) < 1e-08):
             break
 
-    if save_to_file is True:
-        save
-
-
     return _melspec
 
 
@@ -96,7 +93,34 @@
 def save_melspec(out_filepath, melspec):
     np.save(out_filepath, melspec)
 
-def save_melspec_img(out_filepath, melspec):
+def display_spec(dt1, dt2, outfile_dir, title='unknown_spec'):
+    import seaborn as sns
+    plt.rcParams["figure.figsize"] = [10,5]
+    sns.set(font_scale=.7)
+
+    plt.subplot(211)
+    plt.pcolormesh(dt1, cmap='jet')
+
+    plt.subplot(212)
+    plt.pcolormesh(dt2, cmap='jet')
+
+    plt.title(args.exp_name + ','+title); plt.xlabel('Mel-spec frames')
+    os.makedirs(outfile_dir, exist_ok=True)
+    plt.savefig(outfile_dir + '/images/'+ title + '_mspec.png', bbox_inches='tight', dpi=220)
+    plt.close('all')
+
+
+def display_att(att, outfile_dir, title='unknown_att'):
+    import seaborn as sns
+    plt.rcParams["figure.figsize"] = [7,7]
+    sns.set(font_scale=.7)
+    plt.pcolormesh(att, cmap='bone')
+    plt.title(title)
+    plt.xlabel('Mel-spec frames'); plt.ylabel('Text characters')
+    plt.savefig(outfile_dir + '/images/'+ title + '_att.png', bbox_inches='tight', dpi=220)
+    plt.close('all')
+    #plt.pcolormesh(guide, cmap='summer' ); plt.savefig(CHECKPOINT_DIR + '/images/'+ 'att_guide.png', bbox_inches='tight', dpi=100)
+    #plt.close('all')
 
 
 def chr2int(text):
@@ -112,7 +136,10 @@ def chr2int(text):
 
 
 
-#%% Text preparation:
+#%% Text preparation:
+MELSPEC_DIR = 'checkpoints/' + args.exp_name + '/gen_melspec'
+os.makedirs(MELSPEC_DIR, exist_ok=True)
+
 if isinstance(text_input, str): # case: ex) "Hello"
     # Convert input string into one-hot vectors
     x_text = chr2int(text_input)
@@ -146,7 +173,15 @@
     #n_batch = len(test_loader) # number of iteration for one epoch.
         out_melspec = generate_text2mel(model=model, x_text=x_text, args=args, max_output_len=max_output_length)
 
-        # Save to file
+        # Save to .npy file
+        data_id = text_sel[batch_idx]
+        save_melspec(MELSPEC_DIR + '/gen_{0:05d}.npy'.format(data_id), out_melspec) # save gen.npy
+
+        # Save images
+        display_spec(dt1=(x_melspec[0,:,:]).data.cpu().numpy(),
+                     dt2=(out_melspec[0,:,:]).data.cpu().numpy(),
+                     outfile_dir=MELSPEC_DIR,
+                     title='generated_{0:05d}'.format(data_id))
 
 else:
diff --git a/train_text2mel.py b/train_text2mel.py
index b0996bb..a319013 100644
--- a/train_text2mel.py
+++ b/train_text2mel.py
@@ -45,11 +45,14 @@ args = load_config(config_fpath)
         args.exp_name = argv_inputs[1].lower()
         save_config(args, config_fpath)
+    else:
+        args = load_config(config_fpath)
 else:
     args = load_config(config_fpath)
+
 
 # Model type selection:
-if args.model_type is 'base':
+if args.model_type == 'base':
     from model.FastTacotron import Text2Mel
 elif args.model_type is 'BN':
     from model.FastTacotron_BN import Text2Mel
@@ -118,7 +121,7 @@ def save_checkpoint(state):
 save_config(args, CHECKPOINT_DIR + '/config.json')
 
 #%% Data Loading
-DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'
+DATA_ROOT = args.data_root#'/mnt/ssd2/data/LJSpeech-1.1'
 dset_train = LJSpeechDataset(data_root_dir=DATA_ROOT, train_mode=True, output_mode='melspec')
 dset_test = LJSpeechDataset(data_root_dir=DATA_ROOT, train_mode=False, output_mode='melspec')
@@ -171,9 +174,9 @@ def train(epoch):
 #            break
         optimizer.zero_grad()
 
-        out_y, out_att = model(x_text, x_melspec)
+        out_y, out_y_sig, out_att = model(x_text, x_melspec)
 
-        l1 = loss_L1(F.sigmoid(out_y[:,:,:-1]), x_melspec[:,:,1:])
+        l1 = loss_L1(out_y_sig[:,:,:-1], x_melspec[:,:,1:])
         l2 = loss_BCE(out_y[:,:,:-1], x_melspec[:,:,1:])
 
         # l3: Attention loss, W is guide matrices with BxNxT
@@ -201,74 +204,74 @@ def train(epoch):
         if ((epoch in [1,3,5,10,20,30,40]) | (epoch%args.save_interval is 0)) & (select_data in data_idx ):
             sel = np.where(data_idx.cpu()==select_data)[0].data[0]
-            out_y_cpu = (out_y[sel,:,:]).data.cpu().numpy()
+            out_y_sig_cpu = (out_y_sig[sel,:,:]).data.cpu().numpy()
             out_att_cpu = (out_att[sel,:,:]).data.cpu().numpy()
             #org_text = (x_text[sel,:]).data.cpu().numpy()
             org_melspec =(x_melspec[sel,:,:]).data.cpu().numpy()
 
-            display_spec(out_y_cpu, org_melspec, 'Sample {}: epoch = {}'.format(select_data, epoch))
+            display_spec(out_y_sig_cpu, org_melspec, 'Sample {}: epoch = {}'.format(select_data, epoch))
             display_att(out_att_cpu, W[sel,:,:], 'Sample {}: epoch = {}'.format(select_data, epoch))
 
     return train_loss
 
-def generate_text2mel(model_load=None, new_text=None):
-    '''
-    Args:
-    - text: <str> or <list of int>. ex) 'Hello' or [0, 3, 5]
-    - model_load: <nn.Module> or <str>. exp_name must have a directory of checkpoint containing config.json
-    '''
-
-    if isinstance(model_load, str):
-        import os, shutil, pprint #, argparse
-        import numpy as np
-        import torch
-        import torch.nn as nn
-        from torch.utils.data import DataLoader
-        from torch.autograd import Variable
-        from live_dataloader import LJSpeechDataset
-        from util.save_load_config import save_config, load_config
-        from model.FastTacotron import Text2Mel
-
-
-
-
-
-
-
-
-
-
-    model.eval()
-    torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.
-
-    for batch_idx, (data_idx, x_text , x_melspec_org, zs) in tqdm(enumerate(test_loader)):
-        if USE_GPU:
-            x_text, x_melspec_org = Variable(x_text.cuda().long(), requires_grad=False), Variable(x_melspec_org.cuda().float(), requires_grad=False)
-        else:
-            x_text, x_melspec_org = Variable(x_text.long(), requires_grad=False), Variable(x_melspec_org.float(), requires_grad=False)
-        if batch_idx is disp_sel:
-            break
-
-    x_melspec = Variable(torch.FloatTensor(1,80,1).cuda()*0, requires_grad=False)
-
-    import matplotlib.pyplot as plt
-
-    for i in range(220):
-        out_y, out_att = model(x_text[:,:], x_melspec)
-        x_melspec = torch.cat((x_melspec, out_y[:,:,-1].view(1,80,-1)), dim=2)
-        #plt.imshow(out_att[0,:,:].data.cpu().numpy())
-        #plt.show()
-
-
-    plt.imshow(x_melspec[0,:,:].data.cpu().numpy())
-    plt.show()
-
-    plt.imshow(x_melspec_org[0,:,:].data.cpu().numpy())
-    plt.show()
-
-    plt.imshow(out_att[0,:,:].data.cpu().numpy())
-    plt.show()
-
+#def generate_text2mel(model_load=None, new_text=None):
+#    '''
+#    Args:
+#    - text: <str> or <list of int>. ex) 'Hello' or [0, 3, 5]
+#    - model_load: <nn.Module> or <str>. exp_name must have a directory of checkpoint containing config.json
+#    '''
+#
+#    if isinstance(model_load, str):
+#        import os, shutil, pprint #, argparse
+#        import numpy as np
+#        import torch
+#        import torch.nn as nn
+#        from torch.utils.data import DataLoader
+#        from torch.autograd import Variable
+#        from live_dataloader import LJSpeechDataset
+#        from util.save_load_config import save_config, load_config
+#        from model.FastTacotron import Text2Mel
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#
+#    model.eval()
+#    torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.
+#
+#    for batch_idx, (data_idx, x_text , x_melspec_org, zs) in tqdm(enumerate(test_loader)):
+#        if USE_GPU:
+#            x_text, x_melspec_org = Variable(x_text.cuda().long(), requires_grad=False), Variable(x_melspec_org.cuda().float(), requires_grad=False)
+#        else:
+#            x_text, x_melspec_org = Variable(x_text.long(), requires_grad=False), Variable(x_melspec_org.float(), requires_grad=False)
+#        if batch_idx is disp_sel:
+#            break
+#
+#    x_melspec = Variable(torch.FloatTensor(1,80,1).cuda()*0, requires_grad=False)
+#
+#    import matplotlib.pyplot as plt
+#
+#    for i in range(220):
+#        out_y, out_att = model(x_text[:,:], x_melspec)
+#        x_melspec = torch.cat((x_melspec, out_y[:,:,-1].view(1,80,-1)), dim=2)
+#        #plt.imshow(out_att[0,:,:].data.cpu().numpy())
+#        #plt.show()
+#
+#
+#    plt.imshow(x_melspec[0,:,:].data.cpu().numpy())
+#    plt.show()
+#
+#    plt.imshow(x_melspec_org[0,:,:].data.cpu().numpy())
+#    plt.show()
+#
+#    plt.imshow(out_att[0,:,:].data.cpu().numpy())
+#    plt.show()
+#
 
 #%% Train Main Loop
 df_hist = pd.DataFrame(columns=('Total', 'L1', 'BCE','Att'))
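Note on the change above: Text2Mel.forward() now returns the raw decoder output Y alongside Y_sig = sigmoid(Y), so train_text2mel.py can take the L1 loss on the sigmoid output while still passing the unsquashed Y to loss_BCE (presumably a with-logits BCE, given that raw Y is what gets passed), avoiding a double sigmoid. For reference, here is a minimal sketch, not part of this commit, of the autoregressive decode loop that generate_text2mel() in test_text2mel.py performs under the new three-output signature. The early-stop test on out_y comes from the hunk above; feeding back the sigmoid frame is an assumption, since that part of the function body is not visible in this diff, and decode_melspec is a hypothetical name.

import torch

def decode_melspec(model, x_text, max_output_len=220):
    # Sketch only: mirrors generate_text2mel() with the new (Y, Y_sig, A) outputs.
    model.eval()
    torch.set_grad_enabled(False)
    # Seed with a single all-zero mel frame: B=1, 80 mel bins, T=1.
    melspec = torch.zeros(1, 80, 1)
    for _ in range(max_output_len):
        out_y, out_y_sig, _ = model(x_text, melspec)
        # Append the newest predicted frame. Using the sigmoid output here is an
        # assumption, consistent with the L1 loss being taken on out_y_sig in training.
        melspec = torch.cat((melspec, out_y_sig[:, :, -1:]), dim=2)
        # Early stop once the last 5 raw output frames are near-silent,
        # exactly as in the test_text2mel.py hunk above.
        if torch.sum(out_y[0, :, -5:]) < 1e-08:
            break
    return melspec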