vqgan.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from encoder import Encoder
from decoder import Decoder
from codebook import Codebook


class VQGAN(nn.Module):
    """VQ-GAN autoencoder: encoder -> vector-quantized codebook -> decoder."""

    def __init__(self, args):
        super(VQGAN, self).__init__()
        # Encoder/decoder hyperparameters are fixed here rather than taken from args.
        self.encoder = Encoder(double_z=False, z_channels=256, resolution=256, in_channels=3, out_ch=3, ch=128,
                               ch_mult=[1, 1, 2, 2, 4], num_res_blocks=2, attn_resolutions=[16],
                               dropout=0.0).to(device=args.device)
        self.decoder = Decoder(double_z=False, z_channels=256, resolution=256, in_channels=3, out_ch=3, ch=128,
                               ch_mult=[1, 1, 2, 2, 4], num_res_blocks=2, attn_resolutions=[16],
                               dropout=0.0).to(device=args.device)
        self.codebook = Codebook(args).to(device=args.device)
        # 1x1 convolutions mapping between the encoder/decoder latent space and the codebook space.
        self.quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)
        self.post_quant_conv = nn.Conv2d(args.latent_dim, args.latent_dim, 1).to(device=args.device)

    def forward(self, imgs):
        # Encode, project into the codebook space, and L2-normalize before quantization.
        encoded_images = self.encoder(imgs)
        quant_conv_encoded_images = self.quant_conv(encoded_images)
        quant_conv_encoded_images = F.normalize(quant_conv_encoded_images, dim=1)
        codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images)
        # Project the quantized latents back and decode to image space.
        post_quant_conv_mapping = self.post_quant_conv(codebook_mapping)
        decoded_images = self.decoder(post_quant_conv_mapping)
        return decoded_images, codebook_indices, q_loss

    def encode(self, imgs):
        # Same encoding path as forward(), but stops at the quantized latents.
        encoded_images = self.encoder(imgs)
        quant_conv_encoded_images = self.quant_conv(encoded_images)
        quant_conv_encoded_images = F.normalize(quant_conv_encoded_images, dim=1)
        codebook_mapping, codebook_indices, q_loss = self.codebook(quant_conv_encoded_images)
        return codebook_mapping, codebook_indices, q_loss

    def decode(self, z):
        # Map quantized latents back to the decoder's input space and reconstruct images.
        post_quant_conv_mapping = self.post_quant_conv(z)
        decoded_images = self.decoder(post_quant_conv_mapping)
        return decoded_images

    def calculate_lambda(self, perceptual_loss, gan_loss):
        # Adaptive weight for the adversarial term: ratio of the gradient norms of the
        # perceptual/reconstruction loss and the GAN loss w.r.t. the decoder's last layer.
        last_layer = self.decoder.conv_out
        last_layer_weight = last_layer.weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
        gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

        lda = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
        lda = torch.clamp(lda, 0, 1e4).detach()
        return 0.8 * lda

    @staticmethod
    def adopt_weight(disc_factor, i, threshold, value=0.):
        # Disable (or down-weight) the discriminator term until step `i` reaches `threshold`.
        if i < threshold:
            disc_factor = value
        return disc_factor
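
    # A hedged sketch (not from the original file) of how calculate_lambda and adopt_weight
    # are typically combined in a VQGAN training loop; `args.disc_factor` and `args.disc_start`
    # are assumed hyperparameter names, not defined in this file:
    #   lambda_ = vqgan.calculate_lambda(perceptual_rec_loss, g_loss)
    #   disc_factor = VQGAN.adopt_weight(args.disc_factor, global_step, args.disc_start)
    #   vq_loss = perceptual_rec_loss + q_loss + disc_factor * lambda_ * g_loss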

    def load_checkpoint(self, path):
        self.load_state_dict(torch.load(path))
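

# --- Hedged usage sketch (not part of the original file) ---
# Minimal smoke test. `Codebook(args)` is assumed to read fields such as
# `args.num_codebook_vectors` and `args.beta` (common in VQGAN implementations);
# those field names are assumptions and should be matched to codebook.py.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        latent_dim=256,              # must match the encoder's z_channels above
        num_codebook_vectors=1024,   # assumed Codebook hyperparameter
        beta=0.25,                   # assumed commitment-loss weight
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    vqgan = VQGAN(args)
    imgs = torch.randn(1, 3, 256, 256, device=args.device)  # matches in_channels=3, resolution=256
    decoded_images, codebook_indices, q_loss = vqgan(imgs)
    print(decoded_images.shape, codebook_indices.shape, q_loss.item())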