-
Notifications
You must be signed in to change notification settings - Fork 5
/
ssd_gan.py
248 lines (194 loc) · 8.47 KB
/
ssd_gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
"""
Implementation of Base SSD-GAN models.
"""
import torch
from torch_mimicry.nets.basemodel import basemodel
from torch_mimicry.modules import losses
import numpy as np
class SSD_Generator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional generator model.

    ``forward`` is not defined in this class; it is called by
    ``generate_images`` and must be provided elsewhere (presumably by the
    concrete generator subclasses).

    Attributes:
        nz (int): Noise dimension for upsampling.
        ngf (int): Variable controlling generator feature map sizes.
        bottom_width (int): Starting width for upsampling generator output to an image.
        loss_type (str): Name of loss to use for GAN loss.
    """
    def __init__(self, nz, ngf, bottom_width, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.nz = nz
        self.ngf = ngf
        self.bottom_width = bottom_width
        self.loss_type = loss_type

    def generate_images(self, num_images, device=None):
        r"""
        Generates num_images randomly.

        Args:
            num_images (int): Number of images to generate
            device (torch.device): Device to sample the noise on. Falls back
                to ``self.device`` when None.

        Returns:
            Tensor: A batch of generated images.
        """
        if device is None:
            device = self.device

        # One standard-normal latent vector of size nz per requested image.
        noise = torch.randn((num_images, self.nz), device=device)
        return self.forward(noise)

    def compute_gan_loss(self, output):
        r"""
        Computes GAN loss for generator.

        Args:
            output (Tensor): A batch of output logits from the discriminator of shape (N, 1).

        Returns:
            Tensor: A batch of GAN losses for the generator.

        Raises:
            ValueError: If ``self.loss_type`` is not one of
                "gan", "ns", "hinge" or "wasserstein".
        """
        if self.loss_type == "gan":
            return losses.minimax_loss_gen(output)
        if self.loss_type == "ns":
            return losses.ns_loss_gen(output)
        if self.loss_type == "hinge":
            return losses.hinge_loss_gen(output)
        if self.loss_type == "wasserstein":
            return losses.wasserstein_loss_gen(output)

        raise ValueError("Invalid loss_type {} selected.".format(
            self.loss_type))

    def train_step(self,
                   real_batch,
                   netD,
                   optG,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for G.

        Args:
            real_batch (sequence): Batch from the data loader. Only
                ``real_batch[0].shape[0]`` is read, to obtain the current
                batch size, so the first element must be the image tensor.
            netD (nn.Module): Discriminator model for obtaining losses. It is
                called on the fake images and must return a
                ``(out_spectral, out_spatial)`` pair.
            optG (Optimizer): Optimizer for updating generator's parameters.
            log_data (MetricLog): Object for logging; ``add_metric`` is called on it.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and checkpointing.
                Not used in this method.

        Returns:
            The same ``log_data`` object, updated with the generator loss
            after 1 training step.
        """
        self.zero_grad()

        # Get only batch size from real batch
        batch_size = real_batch[0].shape[0]

        # Produce fake images
        fake_images = self.generate_images(num_images=batch_size,
                                           device=device)

        # Compute output logits of D for the fake images
        out_spectral, out_spatial = netD(fake_images)

        # Equal-weight mix of both outputs. The spectral output is detached,
        # so it contributes to the loss value but gradients flow back to the
        # generator only through the spatial output.
        out = 0.5 * out_spectral.detach() + 0.5 * out_spatial
        errG = self.compute_gan_loss(out)

        # Backprop and update gradients
        errG.backward()
        optG.step()

        # Log a plain Python number (as the discriminator does for its
        # losses) rather than a tensor still attached to the autograd graph.
        log_data.add_metric('errG', errG.item(), group='loss')

        return log_data
class SSD_Discriminator(basemodel.BaseModel):
    r"""
    Base class for a generic unconditional discriminator model.

    ``forward`` is not defined in this class; ``train_step`` expects it to
    return a ``(out_spectral, out_spatial)`` pair of outputs.

    Attributes:
        ndf (int): Variable controlling discriminator feature map sizes.
        loss_type (str): Name of loss to use for GAN loss.
    """
    def __init__(self, ndf, loss_type, **kwargs):
        super().__init__(**kwargs)
        self.ndf = ndf
        self.loss_type = loss_type

    def compute_gan_loss(self, output_real, output_fake):
        r"""
        Computes GAN loss for discriminator.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.

        Returns:
            errD (Tensor): A batch of GAN losses for the discriminator.

        Raises:
            ValueError: If ``self.loss_type`` is not one of
                "gan", "ns", "hinge" or "wasserstein".
        """
        # "gan" and "ns" share the same discriminator loss.
        if self.loss_type in ("gan", "ns"):
            return losses.minimax_loss_dis(output_fake=output_fake,
                                           output_real=output_real)
        if self.loss_type == "hinge":
            return losses.hinge_loss_dis(output_fake=output_fake,
                                         output_real=output_real)
        if self.loss_type == "wasserstein":
            return losses.wasserstein_loss_dis(output_fake=output_fake,
                                               output_real=output_real)

        # Name the offending value, matching the generator's error message.
        raise ValueError("Invalid loss_type {} selected.".format(
            self.loss_type))

    def compute_probs(self, output_real, output_fake):
        r"""
        Computes probabilities from real/fake images logits.

        Args:
            output_real (Tensor): A batch of output logits of shape (N, 1) from real images.
            output_fake (Tensor): A batch of output logits of shape (N, 1) from fake images.

        Returns:
            tuple: Average probabilities of real/fake image considered as real for the batch.
        """
        D_x = torch.sigmoid(output_real).mean().item()
        D_Gz = torch.sigmoid(output_fake).mean().item()

        return D_x, D_Gz

    def train_step(self,
                   real_batch,
                   netG,
                   optD,
                   log_data,
                   device=None,
                   global_step=None,
                   **kwargs):
        r"""
        Takes one training step for D.

        Args:
            real_batch (tuple): An ``(images, labels)`` pair; only the images
                are used.
            netG (nn.Module): Generator model for obtaining fake images.
            optD (Optimizer): Optimizer for updating discriminator's parameters.
            log_data (MetricLog): Object for logging; ``add_metric`` is called on it.
            device (torch.device): Device to use for running the model.
            global_step (int): Variable to sync training, logging and checkpointing.
                Not used in this method.

        Returns:
            MetricLog: The same ``log_data`` object, updated with logging
            variables after 1 training step.
        """
        self.zero_grad()

        # Labels are not needed for the unconditional discriminator.
        real_images, _ = real_batch
        batch_size = real_images.shape[0]  # Match batch sizes for last iter

        # Produce logits for real images
        out_spectral_real, out_spatial_real = self.forward(real_images)

        # Produce fake images; detached so no gradient reaches the generator
        fake_images = netG.generate_images(num_images=batch_size,
                                           device=device).detach()

        # Produce logits for fake images
        out_spectral_fake, out_spatial_fake = self.forward(fake_images)

        # Loss on the spectral outputs alone
        errC = self.compute_gan_loss(output_real=out_spectral_real,
                                     output_fake=out_spectral_fake)

        # Loss on the equal-weight mix of both outputs. The spectral outputs
        # are detached here, so errD back-propagates only through the spatial
        # outputs while errC supplies the gradients for the spectral outputs.
        out_real = 0.5 * out_spectral_real.detach() + 0.5 * out_spatial_real
        out_fake = 0.5 * out_spectral_fake.detach() + 0.5 * out_spatial_fake
        errD = self.compute_gan_loss(output_real=out_real,
                                     output_fake=out_fake)

        # Backprop and update gradients
        errD_total = errD + errC
        errD_total.backward()
        optD.step()

        # Batch means of the combined outputs. No sigmoid is applied here and
        # compute_probs is not used.
        # NOTE(review): confirm these raw means are what the 'prob' group is
        # meant to show.
        D_x, D_Gz = out_real.mean().item(), out_fake.mean().item()

        # Log statistics for D once out of loop
        log_data.add_metric('errD', errD.item(), group='loss')
        log_data.add_metric('errC', errC.item(), group='loss')
        log_data.add_metric('D(x)', D_x, group='prob')
        log_data.add_metric('D(G(z))', D_Gz, group='prob')

        return log_data