-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsang_wrapper.py
130 lines (95 loc) · 4.34 KB
/
sang_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import time

import torch
import torch.optim as optim

from sang_gan import *
from sang_plot import *
from sang_utils import *
def wrapper_(opt):
    """Train a WGAN-GP generator/discriminator pair and save sample grids.

    Parameters
    ----------
    opt : dict
        Hyper-parameters. Keys read here: 'num_epochs', 'learning_rate',
        'batch_size', 'dataset' ("celebA" or "cifar10"), 'img_size',
        'b1', 'b2', 'latent_dim', 'lambda_gp', 'n_critic'.

    Side effects: creates an output directory named after the dataset and
    hyper-parameters, writes per-epoch sample images (random and fixed
    noise), a loss-history plot, and an animation of the fixed samples.
    Requires a CUDA device (models and tensors are moved with .cuda()).

    Raises
    ------
    ValueError
        If opt['dataset'] is not one of the supported datasets (previously
        this fell through and crashed later with NameError on train_loader).
    """
    epochs = opt['num_epochs']
    lr = opt['learning_rate']
    bs = opt['batch_size']

    # Select the data loader; fail fast on an unsupported dataset.
    if opt['dataset'] == "celebA":
        data_dir = 'resized_celebA'  # this path depends on your computer
        train_loader = get_celebA_loader(data_dir, bs, opt['img_size'])
    elif opt['dataset'] == "cifar10":
        train_loader = get_cifar10_loader(bs, opt['img_size'])
    else:
        raise ValueError(f"Unknown dataset: {opt['dataset']!r} (expected 'celebA' or 'cifar10')")

    G = generator()
    D = discriminator()
    # Weight initialization
    G.weight_init()
    D.weight_init()
    # put G and D in cuda
    G.cuda()
    D.cuda()

    # Adam optimizer for WGAN-GP
    G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(opt['b1'], opt['b2']))
    D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(opt['b1'], opt['b2']))
    # lr_sche_G = torch.optim.lr_scheduler.StepLR(G_optimizer, step_size=10, gamma=0.99)
    # lr_sche_D = torch.optim.lr_scheduler.StepLR(D_optimizer, step_size=10, gamma=0.99)

    # Output paths are named after the *selected* dataset (they used to be
    # hard-coded to "Cifar10" for the directory and "CelebA" for the images
    # regardless of opt['dataset']).
    tag = opt['dataset']
    save_path = f'{tag}_WGAN-GP_epoch_{epochs}_lr_{lr}_batches_{bs}'
    os.makedirs(save_path, exist_ok=True)
    os.makedirs(os.path.join(save_path, 'Random_results'), exist_ok=True)
    os.makedirs(os.path.join(save_path, 'Fixed_results'), exist_ok=True)

    train_hist = {'D_losses': [], 'G_losses': [], 'total_ptime': []}

    # Fixed noise: one 5x5 grid of latent vectors reused every epoch so
    # sample quality is comparable across epochs. Plain tensors — the
    # deprecated Variable wrapper is a no-op since PyTorch 0.4, and fresh
    # randn noise carries no gradient history, so no no_grad() is needed.
    fixed_z_ = torch.randn((5 * 5, opt['latent_dim'])).view(-1, opt['latent_dim'], 1, 1).cuda()

    print('Training start!')
    start_time = time.time()
    for epoch in range(epochs):
        D_losses = []
        G_losses = []
        epoch_start_time = time.time()
        for i, (x_, _) in enumerate(train_loader):
            # Configure input
            real_image = x_.cuda()

            # ---- train discriminator (critic) D ----
            D_optimizer.zero_grad()
            mini_batch = real_image.shape[0]
            # noise z of shape (mini_batch, latent_dim, 1, 1)
            z = torch.randn((mini_batch, opt['latent_dim'])).view(-1, opt['latent_dim'], 1, 1).cuda()
            fake_image = G(z)
            real_validity = D(real_image)
            fake_validity = D(fake_image)
            gradient_penalty = calculate_gradient_penalty(D, real_image, fake_image, opt['lambda_gp'])
            # Wasserstein critic loss + gradient penalty (WGAN-GP objective)
            D_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + gradient_penalty
            D_loss.backward()
            D_optimizer.step()
            # lr_sche_D.step()
            D_losses.append(D_loss.item())

            # ---- train generator G once every n_critic critic steps ----
            if i % opt['n_critic'] == 0:
                # zero_grad moved inside the guard: it only needs to run
                # when G actually takes a step.
                G_optimizer.zero_grad()
                fake_image = G(z)
                fake_validity = D(fake_image)
                G_loss = -torch.mean(fake_validity)
                G_loss.backward()
                G_optimizer.step()
                # lr_sche_G.step()
                G_losses.append(G_loss.item())

        # Fresh random noise each epoch for the "random" sample grid.
        z_ = torch.randn((5 * 5, opt['latent_dim'])).view(-1, opt['latent_dim'], 1, 1).cuda()

        epoch_end_time = time.time()
        per_epoch_ptime = epoch_end_time - epoch_start_time
        print(f'[{epoch+1}/{epochs}] - epoch time: {per_epoch_ptime}, loss_D: {torch.mean(torch.FloatTensor(D_losses))}, loss_G: {torch.mean(torch.FloatTensor(G_losses))}')
        p = os.path.join(save_path, 'Random_results', f'{tag}_WGAN-GP_{epoch + 1}.png')
        fixed_p = os.path.join(save_path, 'Fixed_results', f'{tag}_WGAN-GP_{epoch + 1}.png')
        show_result(G, (epoch + 1), z_, save=True, path=p)
        show_result(G, (epoch + 1), fixed_z_, save=True, path=fixed_p)
        train_hist['D_losses'].append(torch.mean(torch.FloatTensor(D_losses)))
        train_hist['G_losses'].append(torch.mean(torch.FloatTensor(G_losses)))

    end_time = time.time()
    total_ptime = end_time - start_time
    train_hist['total_ptime'].append(total_ptime)
    print(f'Total time: {total_ptime}')
    print("Training finish!... save training results")
    show_train_hist(train_hist, save=True, path=os.path.join(save_path, f'{tag}_WGAN-GP_train_hist.png'))
    # Compute the last epoch index from `epochs` instead of reading the loop
    # variable after the loop — avoids NameError when num_epochs == 0.
    if epochs > 0:
        make_animation(epochs - 1, save_path)