Skip to content

Commit

Permalink
bug-fixed and including sigmoid output from model
Browse files Browse the repository at this point in the history
  • Loading branch information
mimbres committed Sep 27, 2018
1 parent 81d2318 commit d19b781
Show file tree
Hide file tree
Showing 18 changed files with 141 additions and 79 deletions.
Binary file modified __pycache__/live_dataloader.cpython-36.pyc
Binary file not shown.
13 changes: 13 additions & 0 deletions checkpoints/03/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"exp_name": "03",
"model_type": "base",
"data_root": "/home/mimbres/Documents/Dataset/LJSpeech-1.1",
"max_epoch": 1000,
"batch_train": 16,
"batch_test": 5,
"load": null,
"save_interval": 50,
"sel_display": 9,
"silence_state_guide": -1,
"generate": null
}
9 changes: 9 additions & 0 deletions checkpoints/03/hist.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Total,L1,BCE,Att
0.31811372981593455,0.05411363613011655,0.26173089517951736,0.0022691986378944426
0.27744487956221237,0.03581270973941568,0.23952356732346183,0.0021086022405865695
0.27569847373431616,0.03505160883160676,0.238923375145389,0.00172349025419745
0.2739839101666165,0.03463885126930983,0.23855921627923876,0.0007858422662799905
0.2728093972138666,0.034298791065537154,0.23826296823637633,0.0002476381717573651
0.27221924900495387,0.034004817459031236,0.2380408748232865,0.00017355680674400984
0.27168442572834756,0.0337242923704461,0.2378063600001741,0.00015377319525789262
0.27124152447098143,0.03350268867723687,0.23759463926156363,0.00014419614141819866
Binary file added checkpoints/03/images/Sample 9: epoch = 0_att.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added checkpoints/03/images/Sample 9: epoch = 1_att.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added checkpoints/03/images/Sample 9: epoch = 3_att.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added checkpoints/03/images/Sample 9: epoch = 5_att.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added checkpoints/03/images/att_guide.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion config_template.json
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
{
"exp_name": "00",
"model_type": "base",
"data_root": "/home/mimbres/Documents/Dataset/LJSpeech-1.1",
"max_epoch": 1000,
"batch_train": 32,
"batch_train": 16,
"batch_test": 5,
"load": null,
"save_interval": 50,
Expand Down
4 changes: 2 additions & 2 deletions live_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.utils.data.dataset import Dataset
from nnmnkwii.datasets import FileDataSource, FileSourceDataset, MemoryCacheDataset

DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'
#DATA_ROOT = '/mnt/ssd2/data/LJSpeech-1.1'
N_TRAIN = 13000 # N_TEST = 100 (13000~13099)
X_SPEC_MAX = 193.99077 # Required for feature normalization
X_MELSPEC_MAX = 0.04071
Expand Down Expand Up @@ -84,7 +84,7 @@ class LJSpeechDataset(Dataset):
mode = ['melspec'] : return index, text, melspec
mode = ['SSRN'] : return index, spec, melspec
'''
def __init__(self, data_root_dir=DATA_ROOT, train_mode=False , output_mode='melspec', transform=None, data_sel=None):
def __init__(self, data_root_dir=None, train_mode=False , output_mode='melspec', transform=None, data_sel=None):


self.wav_root_dir = data_root_dir + '/wavs/'
Expand Down
5 changes: 3 additions & 2 deletions model/FastTacotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,10 @@ def forward(self, x_text, x_melspec, forced_att=None):

# Decoding Mel-spectrogram, Y
Y = self.audio_dec(RQ) # Bx80xT with T=T_audio
Y_sig = F.sigmoid(Y)

if self.optional_output is True:
return Y, A, K, V, Q
return Y, Y_sig, A, K, V, Q
else:
return Y, A
return Y, Y_sig, A

Binary file modified model/__pycache__/FastTacotron.cpython-36.pyc
Binary file not shown.
55 changes: 45 additions & 10 deletions test_text2mel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.autograd import Variable
from live_dataloader import LJSpeechDataset
from util.save_load_config import save_config, load_config
from tqdm import tqdm


DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'


Expand Down Expand Up @@ -69,9 +71,8 @@
def generate_text2mel(model=nn.Module,
x_text=torch.autograd.Variable,
args=argparse.Namespace,
max_output_len=int,
save_to_file=True):

max_output_len=int):

model.eval()
torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.

Expand All @@ -85,18 +86,41 @@ def generate_text2mel(model=nn.Module,
if (torch.sum(out_y[0,:,-5:]) < 1e-08):
break

if save_to_file is True:
save


return _melspec


def save_melspec(out_filepath, melspec):
    """Save a mel-spectrogram to ``out_filepath`` in NumPy ``.npy`` format.

    Args:
        out_filepath: Destination path handed to ``np.save``.
        melspec: Array-like data, or a tensor-like object exposing
            ``detach()`` and ``cpu()`` (e.g. a torch tensor).

    ``np.save`` cannot serialize a tensor that lives on an accelerator or
    that tracks gradients, so tensor-like inputs are first detached and
    copied to host memory. Plain arrays are passed through unchanged.
    """
    if hasattr(melspec, 'detach') and hasattr(melspec, 'cpu'):
        melspec = melspec.detach().cpu().numpy()
    np.save(out_filepath, melspec)


def save_melspec_img(out_filepath, melspec):
def display_spec(dt1, dt2, outfile_dir, title='unknown_spec'):
    """Plot two spectrograms stacked vertically and save them as a PNG.

    Args:
        dt1: 2-D data drawn in the upper panel.
        dt2: 2-D data drawn in the lower panel.
        outfile_dir: Base output directory; the image is written to
            ``<outfile_dir>/images/<title>_mspec.png``.
        title: Used both in the figure title and in the output file name.

    Reads the module-level ``args.exp_name`` for the figure title.
    """
    import seaborn as sns
    plt.rcParams["figure.figsize"] = [10,5]
    sns.set(font_scale=.7)

    plt.subplot(211)
    plt.pcolormesh(dt1, cmap='jet')

    plt.subplot(212)
    plt.pcolormesh(dt2, cmap='jet')

    # Title and x-label are applied to the currently active (lower) panel.
    plt.title(args.exp_name + ','+title); plt.xlabel('Mel-spec frames')

    # The figure is saved under the 'images' sub-directory, so that is the
    # directory that must exist (creating only outfile_dir is not enough).
    image_dir = outfile_dir + '/images'
    os.makedirs(image_dir, exist_ok=True)
    plt.savefig(image_dir + '/' + title + '_mspec.png', bbox_inches='tight', dpi=220)
    plt.close('all')


def display_att(att, outfile_dir, title='unknown_att'):
    """Plot an attention matrix and save it as a PNG.

    Args:
        att: 2-D attention data; drawn with text characters on the y-axis
            and mel-spectrogram frames on the x-axis (per the axis labels).
        outfile_dir: Base output directory; the image is written to
            ``<outfile_dir>/images/<title>_att.png``.
        title: Used both as the figure title and in the output file name.
    """
    import seaborn as sns
    plt.rcParams["figure.figsize"] = [7,7]
    sns.set(font_scale=.7)
    plt.pcolormesh(att, cmap='bone')
    plt.title(title)
    plt.xlabel('Mel-spec frames'); plt.ylabel('Text characters')

    # Make sure the 'images' sub-directory exists before saving; savefig
    # does not create missing parent directories.
    image_dir = outfile_dir + '/images'
    os.makedirs(image_dir, exist_ok=True)
    plt.savefig(image_dir + '/' + title + '_att.png', bbox_inches='tight', dpi=220)
    plt.close('all')
#plt.pcolormesh(guide, cmap='summer' ); plt.savefig(CHECKPOINT_DIR + '/images/'+ 'att_guide.png', bbox_inches='tight', dpi=100)
#plt.close('all')


def chr2int(text):
Expand All @@ -112,7 +136,10 @@ def chr2int(text):



#%% Text preparation:
#%% Text preparation:
MELSPEC_DIR = 'checkpoints/' + args.exp_name + '/gen_melspec'
os.makedirs(MELSPEC_DIR, exist_ok=True)

if isinstance(text_input, str): # case: ex) "Hello"
# Convert input string into one-hot vectors
x_text = chr2int(text_input)
Expand Down Expand Up @@ -146,7 +173,15 @@ def chr2int(text):
#n_batch = len(test_loader) # number of iteration for one epoch.
out_melspec = generate_text2mel(model=model, x_text=x_text, args=args, max_output_len=max_output_length)

# Save to file
# Save to .npy file
data_id = text_sel[batch_idx]
save_melspec(MELSPEC_DIR + '/gen_{0:05d}.npy'.format(data_id), out_melspec) # save gen<original data id>.npy

# Save images
display_spec(dt1=(x_melspec[0,:,:]).data.cpu().numpy(),
dt2=(out_melspec[0,:,:]).data.cpu().numpy(),
outfile_dir=MELSPEC_DIR,
title='generated_{0:05d}'.format(data_id))


else:
Expand Down
131 changes: 67 additions & 64 deletions train_text2mel.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@
args = load_config(config_fpath)
args.exp_name = argv_inputs[1].lower()
save_config(args, config_fpath)
else:
args = load_config(config_fpath)
else:
args = load_config(config_fpath)


# Model type selection:
if args.model_type is 'base':
if args.model_type == 'base':
from model.FastTacotron import Text2Mel
elif args.model_type is 'BN':
from model.FastTacotron_BN import Text2Mel
Expand Down Expand Up @@ -118,7 +121,7 @@ def save_checkpoint(state):
save_config(args, CHECKPOINT_DIR + '/config.json')

#%% Data Loading
DATA_ROOT = '/mnt/ssd3/data/LJSpeech-1.1'
DATA_ROOT = args.data_root#'/mnt/ssd2/data/LJSpeech-1.1'

dset_train = LJSpeechDataset(data_root_dir=DATA_ROOT, train_mode=True, output_mode='melspec')
dset_test = LJSpeechDataset(data_root_dir=DATA_ROOT, train_mode=False, output_mode='melspec')
Expand Down Expand Up @@ -171,9 +174,9 @@ def train(epoch):
# break

optimizer.zero_grad()
out_y, out_att = model(x_text, x_melspec)
out_y, out_y_sig, out_att = model(x_text, x_melspec)

l1 = loss_L1(F.sigmoid(out_y[:,:,:-1]), x_melspec[:,:,1:])
l1 = loss_L1(out_y_sig[:,:,:-1], x_melspec[:,:,1:])
l2 = loss_BCE(out_y[:,:,:-1], x_melspec[:,:,1:])

# l3: Attention loss, W is guide matrices with BxNxT
Expand Down Expand Up @@ -201,74 +204,74 @@ def train(epoch):
if ((epoch in [1,3,5,10,20,30,40]) | (epoch%args.save_interval is 0)) & (select_data in data_idx ):
sel = np.where(data_idx.cpu()==select_data)[0].data[0]

out_y_cpu = (out_y[sel,:,:]).data.cpu().numpy()
out_y_sig_cpu = (out_y_sig[sel,:,:]).data.cpu().numpy()
out_att_cpu = (out_att[sel,:,:]).data.cpu().numpy()
#org_text = (x_text[sel,:]).data.cpu().numpy()
org_melspec =(x_melspec[sel,:,:]).data.cpu().numpy()

display_spec(out_y_cpu, org_melspec, 'Sample {}: epoch = {}'.format(select_data, epoch))
display_spec(out_y_sig_cpu, org_melspec, 'Sample {}: epoch = {}'.format(select_data, epoch))
display_att(out_att_cpu, W[sel,:,:], 'Sample {}: epoch = {}'.format(select_data, epoch))

return train_loss

def generate_text2mel(model_load=None, new_text=None):
'''
Args:
- text: <str> or <list index(in test data) to display>. ex) 'Hello' or [0, 3, 5]
- model_load: <existing model to load> or <exp_name>. exp_name must have a directory of checkpoint containing config.json
'''

if isinstance(model_load, str):
import os, shutil, pprint #, argparse
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from live_dataloader import LJSpeechDataset
from util.save_load_config import save_config, load_config
from model.FastTacotron import Text2Mel










model.eval()
torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.

for batch_idx, (data_idx, x_text , x_melspec_org, zs) in tqdm(enumerate(test_loader)):
if USE_GPU:
x_text, x_melspec_org = Variable(x_text.cuda().long(), requires_grad=False), Variable(x_melspec_org.cuda().float(), requires_grad=False)
else:
x_text, x_melspec_org = Variable(x_text.long(), requires_grad=False), Variable(x_melspec_org.float(), requires_grad=False)
if batch_idx is disp_sel:
break

x_melspec = Variable(torch.FloatTensor(1,80,1).cuda()*0, requires_grad=False)

import matplotlib.pyplot as plt

for i in range(220):
out_y, out_att = model(x_text[:,:], x_melspec)
x_melspec = torch.cat((x_melspec, out_y[:,:,-1].view(1,80,-1)), dim=2)
#plt.imshow(out_att[0,:,:].data.cpu().numpy())
#plt.show()


plt.imshow(x_melspec[0,:,:].data.cpu().numpy())
plt.show()

plt.imshow(x_melspec_org[0,:,:].data.cpu().numpy())
plt.show()

plt.imshow(out_att[0,:,:].data.cpu().numpy())
plt.show()

#def generate_text2mel(model_load=None, new_text=None):
# '''
# Args:
# - text: <str> or <list index(in test data) to display>. ex) 'Hello' or [0, 3, 5]
# - model_load: <existing model to load> or <exp_name>. exp_name must have a directory of checkpoint containing config.json
# '''
#
# if isinstance(model_load, str):
# import os, shutil, pprint #, argparse
# import numpy as np
# import torch
# import torch.nn as nn
# from torch.utils.data import DataLoader
# from torch.autograd import Variable
# from live_dataloader import LJSpeechDataset
# from util.save_load_config import save_config, load_config
# from model.FastTacotron import Text2Mel
#
#
#
#
#
#
#
#
#
#
# model.eval()
# torch.set_grad_enabled(False) # Pytorch 0.4: "volatile=True" is deprecated.
#
# for batch_idx, (data_idx, x_text , x_melspec_org, zs) in tqdm(enumerate(test_loader)):
# if USE_GPU:
# x_text, x_melspec_org = Variable(x_text.cuda().long(), requires_grad=False), Variable(x_melspec_org.cuda().float(), requires_grad=False)
# else:
# x_text, x_melspec_org = Variable(x_text.long(), requires_grad=False), Variable(x_melspec_org.float(), requires_grad=False)
# if batch_idx is disp_sel:
# break
#
# x_melspec = Variable(torch.FloatTensor(1,80,1).cuda()*0, requires_grad=False)
#
# import matplotlib.pyplot as plt
#
# for i in range(220):
# out_y, out_att = model(x_text[:,:], x_melspec)
# x_melspec = torch.cat((x_melspec, out_y[:,:,-1].view(1,80,-1)), dim=2)
# #plt.imshow(out_att[0,:,:].data.cpu().numpy())
# #plt.show()
#
#
# plt.imshow(x_melspec[0,:,:].data.cpu().numpy())
# plt.show()
#
# plt.imshow(x_melspec_org[0,:,:].data.cpu().numpy())
# plt.show()
#
# plt.imshow(out_att[0,:,:].data.cpu().numpy())
# plt.show()
#

#%% Train Main Loop
df_hist = pd.DataFrame(columns=('Total', 'L1', 'BCE','Att'))
Expand Down

0 comments on commit d19b781

Please sign in to comment.