forked from vincentherrmann/pytorch-wavenet
-
Notifications
You must be signed in to change notification settings - Fork 4
/
wavenet_training.py
125 lines (107 loc) · 4.55 KB
/
wavenet_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.optim as optim
import torch.utils.data
import time
from datetime import datetime
import torch.nn.functional as F
from torch.autograd import Variable
from model_logging import Logger
from wavenet_modules import *
def print_last_loss(opt):
    """Print the most recent entry of ``opt.losses``.

    Args:
        opt: Any object exposing an indexable ``losses`` sequence.
    """
    history = opt.losses
    latest = history[-1]
    print("loss: ", latest)
def print_last_validation_result(opt):
    """Print the most recent entry of ``opt.validation_results``.

    Args:
        opt: Any object exposing an indexable ``validation_results`` sequence.
    """
    results = opt.validation_results
    newest = results[-1]
    print("validation loss: ", newest)
class WavenetTrainer:
    """Train a model with cross-entropy loss, optional gradient clipping and snapshots.

    Args:
        model: Module whose output is compared against the flattened targets
            with ``F.cross_entropy``.
        dataset: Dataset yielding ``(x, target)`` pairs. ``validate`` also reads
            ``dataset.target_length`` and temporarily sets ``dataset.train``
            to False.
        optimizer: Optimizer class, instantiated with ``params``, ``lr`` and
            ``weight_decay``.
        lr: Learning rate passed to the optimizer.
        weight_decay: Weight decay passed to the optimizer.
        gradient_clipping: Maximum gradient norm; ``None`` disables clipping.
        logger: Object with a ``log(step, loss)`` method and a writable
            ``trainer`` attribute. ``None`` (the default) creates a fresh
            ``Logger()`` for this trainer.
        snapshot_path: Directory that snapshots are saved into; ``None``
            disables snapshots.
        snapshot_name: Filename prefix for snapshots.
        snapshot_interval: Save a snapshot every this many optimizer steps.
        dtype: Tensor type that inputs are cast to.
        ltype: Tensor type that targets are cast to.
    """

    def __init__(self,
                 model,
                 dataset,
                 optimizer=optim.Adam,
                 lr=0.001,
                 weight_decay=0,
                 gradient_clipping=None,
                 logger=None,
                 snapshot_path=None,
                 snapshot_name='snapshot',
                 snapshot_interval=1000,
                 dtype=torch.FloatTensor,
                 ltype=torch.LongTensor):
        self.model = model
        self.dataset = dataset
        self.dataloader = None  # created by train(); validate() iterates it
        self.lr = lr
        self.weight_decay = weight_decay
        self.clip = gradient_clipping
        self.optimizer_type = optimizer
        self.optimizer = self.optimizer_type(params=self.model.parameters(),
                                             lr=self.lr,
                                             weight_decay=self.weight_decay)
        # A ``Logger()`` default in the signature would be evaluated once and
        # shared by every trainer. The assignment to ``logger.trainer`` below
        # would then re-point it at whichever trainer was built last.
        if logger is None:
            logger = Logger()
        self.logger = logger
        self.logger.trainer = self
        self.snapshot_path = snapshot_path
        self.snapshot_name = snapshot_name
        self.snapshot_interval = snapshot_interval
        self.dtype = dtype
        self.ltype = ltype

    @staticmethod
    def _to_scalar(tensor):
        """Return the Python number held by a one-element tensor."""
        # ``.item()`` exists from PyTorch 0.4 on. Indexing a 0-dim tensor with
        # ``[0]`` raises there, while older releases only support indexing.
        if hasattr(tensor, 'item'):
            return tensor.item()
        return tensor[0]

    def train(self,
              batch_size=32,
              epochs=10,
              continue_training_at_step=0,
              num_workers=8):
        """Run ``epochs`` passes over the dataset, logging the loss after every step.

        Args:
            batch_size: Mini-batch size for the DataLoader.
            epochs: Number of passes over the dataset.
            continue_training_at_step: Initial value of the step counter (for
                resuming). It affects logged step numbers and snapshot timing.
            num_workers: DataLoader worker processes (previously fixed at 8).
        """
        self.model.train()
        self.dataloader = torch.utils.data.DataLoader(self.dataset,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=num_workers,
                                                      pin_memory=False)
        # Prefer the in-place ``clip_grad_norm_``. The un-suffixed name is
        # deprecated in newer PyTorch but is the only one in old releases.
        clip_grad_norm = (getattr(torch.nn.utils, 'clip_grad_norm_', None)
                          or torch.nn.utils.clip_grad_norm)
        step = continue_training_at_step
        # Time the first `timing_steps` steps of *this call*. The count is
        # independent of epoch boundaries and of the resume offset.
        timing_steps = 100
        steps_this_run = 0
        run_start = time.time()
        for current_epoch in range(epochs):
            print("epoch", current_epoch)
            for (x, target) in self.dataloader:
                x = Variable(x.type(self.dtype))
                target = Variable(target.view(-1).type(self.ltype))
                output = self.model(x)
                loss = F.cross_entropy(output.squeeze(), target.squeeze())
                self.optimizer.zero_grad()
                loss.backward()
                loss = self._to_scalar(loss.data)
                if self.clip is not None:
                    clip_grad_norm(self.model.parameters(), self.clip)
                self.optimizer.step()
                step += 1
                steps_this_run += 1

                if steps_this_run == timing_steps:
                    per_step = (time.time() - run_start) / timing_steps
                    print("one training step does take approximately "
                          + str(per_step) + " seconds")

                # Only the snapshot depends on snapshot_path. Logging must
                # still happen on these steps (the old `continue` skipped it).
                if step % self.snapshot_interval == 0 and self.snapshot_path is not None:
                    time_string = time.strftime("%Y-%m-%d_%H-%M-%S", time.gmtime())
                    torch.save(self.model,
                               self.snapshot_path + '/' + self.snapshot_name + '_' + time_string)

                self.logger.log(step, loss)

    def validate(self):
        """Evaluate the model over ``self.dataloader``.

        ``self.dataloader`` is created by ``train()``, so that must have been
        called first. The model is switched to eval mode and
        ``dataset.train`` is set to False for the duration. Both are restored
        even if evaluation raises.

        Returns:
            Tuple ``(avg_loss, avg_accuracy)``:
                avg_loss: Summed batch losses divided by the number of batches.
                avg_accuracy: Correct predictions divided by
                    ``len(dataset) * dataset.target_length``.
        """
        self.model.eval()
        self.dataset.train = False
        try:
            total_loss = 0
            accurate_classifications = 0
            for (x, target) in self.dataloader:
                x = Variable(x.type(self.dtype))
                target = Variable(target.view(-1).type(self.ltype))
                output = self.model(x)
                loss = F.cross_entropy(output.squeeze(), target.squeeze())
                total_loss += self._to_scalar(loss.data)
                predictions = torch.max(output, 1)[1].view(-1)
                correct_pred = torch.eq(target, predictions)
                accurate_classifications += self._to_scalar(torch.sum(correct_pred).data)
            avg_loss = total_loss / len(self.dataloader)
            # float() guards against integer division if run under Python 2.
            avg_accuracy = float(accurate_classifications) / (
                len(self.dataset) * self.dataset.target_length)
        finally:
            self.dataset.train = True
            self.model.train()
        return avg_loss, avg_accuracy
def generate_audio(model,
                   length=8000,
                   temperatures=(0., 1.)):
    """Generate one sample sequence per temperature and stack them.

    Args:
        model: Object with a ``generate_fast(length, temperature=...)`` method
            whose results can be stacked by ``np.stack``.
        length: Passed to ``generate_fast`` as the requested length.
        temperatures: Iterable of sampling temperatures, one output row each.
            The default is an immutable tuple rather than a shared list.

    Returns:
        ``np.ndarray`` stacking the ``generate_fast`` results along a new
        leading axis, in the order of ``temperatures``.
    """
    # numpy is not imported explicitly in this module's visible imports.
    # Import it here instead of relying on a star import to provide ``np``.
    import numpy as np

    samples = [model.generate_fast(length, temperature=temp) for temp in temperatures]
    return np.stack(samples, axis=0)