Skip to content

Commit

Permalink
feat: add discriminator UNet and CutMix as one option
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Nov 10, 2023
1 parent 5ef5f8b commit dc69f51
Show file tree
Hide file tree
Showing 10 changed files with 971 additions and 13 deletions.
38 changes: 30 additions & 8 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def forward_GAN(self):

if self.use_temporal:
self.compute_temporal_fake(objective_domain="B")

if hasattr(self, "netG_B"):
self.compute_temporal_fake(objective_domain="A")

Expand Down Expand Up @@ -401,7 +402,6 @@ def compute_D_loss(self):
loss_name,
loss_value,
)

self.loss_D_tot += loss_value

def compute_G_loss_GAN_generic(
Expand Down Expand Up @@ -479,7 +479,6 @@ def compute_G_loss_GAN(self):
loss_name,
loss_value,
)

self.loss_G_tot += loss_value

if self.opt.train_temporal_criterion:
Expand Down Expand Up @@ -531,6 +530,9 @@ def set_discriminators_info(self):
else:
train_gan_mode = self.opt.train_gan_mode

if "une_discriminator_mha" in discriminator_name:
train_use_cutmix = self.opt.train_use_cutmix

if "projected" in discriminator_name:
dataaug_D_diffusion = self.opt.dataaug_D_diffusion
dataaug_D_diffusion_every = self.opt.dataaug_D_diffusion_every
Expand Down Expand Up @@ -562,6 +564,26 @@ def set_discriminators_info(self):
real_name = "temporal_real"
compute_every = self.opt.D_temporal_every

elif "unet_discriminator_mha" in discriminator_name:
loss_calculator = loss.DualDiscriminatorGANLoss(
netD=getattr(self, "net" + discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
train_use_cutmix=self.opt.train_use_cutmix,
)
fake_name = None
real_name = None
compute_every = 1

else:
fake_name = None
real_name = None
Expand Down Expand Up @@ -598,12 +620,6 @@ def set_discriminators_info(self):
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)

setattr(
self,
loss_calculator_name,
loss_calculator,
)

if "depth" in discriminator_name:
fake_name = "fake_depth"
real_name = "real_depth"
Expand All @@ -614,6 +630,12 @@ def set_discriminators_info(self):
fake_name = "fake_sam"
real_name = "real_sam"

setattr(
self,
loss_calculator_name,
loss_calculator,
)

self.objects_to_update.append(getattr(self, loss_calculator_name))

self.discriminators.append(
Expand Down
32 changes: 32 additions & 0 deletions models/gan_networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
import torch.nn as nn
import functools
from torch.optim import lr_scheduler
Expand Down Expand Up @@ -41,6 +42,9 @@
UNet as UNet_mha,
UViT as UViT,
)
from .modules.unet_generator_attn.unet_discriminator_attn import (
UNet as UNet_discriminator_mha,
)


def define_G(
Expand Down Expand Up @@ -244,6 +248,8 @@ def define_G(
def define_D(
D_netDs,
model_input_nc,
model_output_nc,
D_num_downs,
D_ndf,
D_n_layers,
D_norm,
Expand All @@ -266,14 +272,21 @@ def define_D(
f_s_semantic_nclasses,
model_depth_network,
train_feat_wavelet,
G_unet_mha_num_head_channels,
G_unet_mha_res_blocks,
G_unet_mha_channel_mults,
G_unet_mha_norm_layer,
G_unet_mha_group_norm_size,
**unused_options
):

"""Create a discriminator
Parameters:
model_input_nc (int) -- the number of channels in input images
model_output_nc (int) -- the number of channels in output images
D_ndf (int) -- the number of filters in the first conv layer
D_num_downs (int) -- the number of downsamplings in UNet. For example, if D_num_downs == 7, an image of size 128x128 will become of size 1x1 at the bottleneck
netD (str) -- the architecture's name: basic | n_layers | pixel
D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
D_norm (str) -- the type of normalization layers used in the network.
Expand Down Expand Up @@ -432,6 +445,25 @@ def define_D(
)
return_nets[netD] = init_net(net, model_init_type, model_init_gain)

elif netD == "unet_discriminator_mha":
net = UNet_discriminator_mha(
image_size=data_crop_size,
in_channel=model_input_nc,
inner_channel=D_ndf,
cond_embed_dim=D_ndf * 4,
out_channel=model_output_nc,
res_blocks=G_unet_mha_res_blocks,
attn_res=[16],
channel_mults=G_unet_mha_channel_mults, # e.g. (1, 2, 4, 8)
num_head_channels=G_unet_mha_num_head_channels,
tanh=True,
n_timestep_train=0, # unused
n_timestep_test=0, # unused
norm=G_unet_mha_norm_layer,
group_norm_size=G_unet_mha_group_norm_size,
)
return_nets[netD] = net # init_net(net, model_init_type, model_init_gain)

else:
raise NotImplementedError(
"Discriminator model name [%s] is not recognized" % netD
Expand Down
4 changes: 3 additions & 1 deletion models/modules/discriminators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import functools

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F

from .utils import spectral_norm, normal_init

# NOTE(review): this switches autograd anomaly detection on as an import-time
# side effect of this module. PyTorch documents the mode as a debugging aid
# that slows execution down -- confirm it is meant to stay enabled outside of
# debugging sessions.
torch.autograd.set_detect_anomaly(True)


class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
Expand Down
135 changes: 133 additions & 2 deletions models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch.nn.functional as F
import random
import math

# import numpy as np
import numpy as np
from util.cutmix import CutMix


class GANLoss(nn.Module):
Expand Down Expand Up @@ -394,6 +394,137 @@ def compute_loss_G(self, netD, real, fake):
return loss_G


class DualDiscriminatorGANLoss(DiscriminatorLoss):
    """GAN loss for a discriminator that returns two predictions.

    ``netD(x)`` is unpacked as ``(pixel_prediction, bottleneck_prediction)``
    and both predictions are scored with the same ``GANLoss`` criterion.
    When ``train_use_cutmix`` is set, a CutMix consistency term on the pixel
    prediction is added to the discriminator loss, following
    https://arxiv.org/abs/2002.12655.
    """

    def __init__(
        self,
        netD,
        device,
        dataaug_APA_p,
        dataaug_APA_target,
        train_batch_size,
        dataaug_APA_nimg,
        dataaug_APA_every,
        dataaug_D_label_smooth,
        train_gan_mode,
        dataaug_APA,
        dataaug_D_diffusion,
        dataaug_D_diffusion_every,
        train_use_cutmix,
    ):
        """Build the criterion and store the CutMix switch.

        Parameters:
            dataaug_D_label_smooth (bool) -- if true, the real target label is 0.9 instead of 1.0
            train_gan_mode (str)          -- GAN objective name forwarded to GANLoss
            train_use_cutmix (bool)       -- add the CutMix consistency term to the D loss
        All remaining arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            netD,
            device,
            dataaug_APA_p,
            dataaug_APA_target,
            train_batch_size,
            dataaug_APA_nimg,
            dataaug_APA_every,
            dataaug_APA,
            dataaug_D_diffusion,
            dataaug_D_diffusion_every,
        )
        target_real_label = 0.9 if dataaug_D_label_smooth else 1.0

        self.gan_mode = train_gan_mode
        self.train_use_cutmix = train_use_cutmix

        self.criterionGAN = GANLoss(
            self.gan_mode, target_real_label=target_real_label
        ).to(self.device)

    def _cutmix_real_fake_pairwise(self, real, fake):
        """Mix each real image with its paired fake image using one mask per sample.

        Parameters:
            real (tensor) -- batch of real images
            fake (tensor) -- batch of generated images, same size as `real`
        Returns (mixed_images, masks). The mix is ``masks * fake + (1 - masks) * real``,
        i.e. a mask value of 1 selects the fake image.
        Raises ValueError when `real` and `fake` differ in size.
        """
        if real.size() != fake.size():
            raise ValueError("Real and fake images should have the same dimensions.")
        # One independent mask per sample; CutMix is called with the image size
        # taken from dim 2. NOTE(review): assumes CutMix returns a 2-D mask so
        # that the stacked result becomes (batch, 1, H, W) -- TODO confirm.
        masks = torch.stack(
            [CutMix(real.size(2)).to(real.device) for _ in range(real.size(0))]
        ).unsqueeze(1)
        mixed_images = masks * fake + (1 - masks) * real
        return mixed_images, masks

    def compute_loss_D(self, netD, real, fake, fake_2):
        """Calculate GAN loss for the discriminator
        Parameters:
            netD (network)      -- the discriminator D, returning (pixel, bottleneck) predictions
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator
            fake_2              -- forwarded to the parent implementation only
        Return the discriminator loss.
        This method does not call backward() itself.
        """
        super().compute_loss_D(netD, real, fake, fake_2)

        lambda_loss = 0.5

        # Real
        pred_real_pixel, pred_real_bottleneck = netD(self.real)
        loss_pred_real_pixel = self.criterionGAN(pred_real_pixel, True)
        loss_pred_real_bottleneck = self.criterionGAN(pred_real_bottleneck, True)
        self.loss_D_real = loss_pred_real_pixel + loss_pred_real_bottleneck

        # Fake; detached so that no gradient reaches the generator here.
        fake_input = self.fake.detach()
        pred_fake_pixel, pred_fake_bottleneck = netD(fake_input)
        loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False)
        loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False)
        loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel

        loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss

        if self.train_use_cutmix:
            lambda_cutmix = 0.5

            cutmix_img, label_masks = self._cutmix_real_fake_pairwise(
                self.real, fake_input
            )
            pred_cutmix_fake_pixel, _ = netD(cutmix_img.detach())

            # The consistency target must be mixed with the SAME mask convention
            # as the image (mask == 1 -> fake). The single-channel mask is
            # broadcast over the prediction's channels, so this works for any
            # number of output channels.
            consistent_pred_pixel = (
                label_masks * pred_fake_pixel + (1 - label_masks) * pred_real_pixel
            )

            loss_cutmix_pixel = torch.norm(
                pred_cutmix_fake_pixel - consistent_pred_pixel, p=2
            ).pow(2)

            loss_D = loss_D + loss_cutmix_pixel * lambda_cutmix

        return loss_D

    def compute_loss_G(self, netD, real, fake):
        """Calculate the generator's GAN loss against both discriminator predictions.
        Parameters:
            netD (network)      -- the discriminator D, returning (pixel, bottleneck) predictions
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator
        Return the sum of the pixel and bottleneck losses computed with target "real".
        """
        super().compute_loss_G(netD, real, fake)
        pred_fake_pixel, pred_fake_bottleneck = netD(self.fake)

        loss_G_pixel = self.criterionGAN(pred_fake_pixel, True, relu=False)
        loss_G_bottleneck = self.criterionGAN(pred_fake_bottleneck, True, relu=False)

        loss_G = loss_G_pixel + loss_G_bottleneck
        return loss_G


class MultiScaleDiffusionLoss(nn.Module):
"""
Multiscale diffusion loss such as in 2301.11093.
Expand Down
Loading

0 comments on commit dc69f51

Please sign in to comment.