Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat discriminator unet mha #580

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 30 additions & 8 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def forward_GAN(self):

if self.use_temporal:
self.compute_temporal_fake(objective_domain="B")

if hasattr(self, "netG_B"):
self.compute_temporal_fake(objective_domain="A")

Expand Down Expand Up @@ -401,7 +402,6 @@ def compute_D_loss(self):
loss_name,
loss_value,
)

self.loss_D_tot += loss_value

def compute_G_loss_GAN_generic(
Expand Down Expand Up @@ -479,7 +479,6 @@ def compute_G_loss_GAN(self):
loss_name,
loss_value,
)

self.loss_G_tot += loss_value

if self.opt.train_temporal_criterion:
Expand Down Expand Up @@ -531,6 +530,9 @@ def set_discriminators_info(self):
else:
train_gan_mode = self.opt.train_gan_mode

if "unet_discriminator_mha" in discriminator_name:
    # Substring fixed from "une_discriminator_mha" so the check matches the
    # discriminator name used by the other "unet_discriminator_mha" branches.
    train_use_cutmix = self.opt.train_use_cutmix

if "projected" in discriminator_name:
dataaug_D_diffusion = self.opt.dataaug_D_diffusion
dataaug_D_diffusion_every = self.opt.dataaug_D_diffusion_every
Expand Down Expand Up @@ -562,6 +564,26 @@ def set_discriminators_info(self):
real_name = "temporal_real"
compute_every = self.opt.D_temporal_every

elif "unet_discriminator_mha" in discriminator_name:
loss_calculator = loss.DualDiscriminatorGANLoss(
netD=getattr(self, "net" + discriminator_name),
device=self.device,
dataaug_APA_p=self.opt.dataaug_APA_p,
dataaug_APA_target=self.opt.dataaug_APA_target,
train_batch_size=self.opt.train_batch_size,
dataaug_APA_nimg=self.opt.dataaug_APA_nimg,
dataaug_APA_every=self.opt.dataaug_APA_every,
dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth,
train_gan_mode=train_gan_mode,
dataaug_APA=self.opt.dataaug_APA,
dataaug_D_diffusion=dataaug_D_diffusion,
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
train_use_cutmix=self.opt.train_use_cutmix,
)
fake_name = None
real_name = None
compute_every = 1

else:
fake_name = None
real_name = None
Expand Down Expand Up @@ -598,12 +620,6 @@ def set_discriminators_info(self):
dataaug_D_diffusion_every=dataaug_D_diffusion_every,
)

setattr(
self,
loss_calculator_name,
loss_calculator,
)

if "depth" in discriminator_name:
fake_name = "fake_depth"
real_name = "real_depth"
Expand All @@ -614,6 +630,12 @@ def set_discriminators_info(self):
fake_name = "fake_sam"
real_name = "real_sam"

setattr(
self,
loss_calculator_name,
loss_calculator,
)

self.objects_to_update.append(getattr(self, loss_calculator_name))

self.discriminators.append(
Expand Down
32 changes: 32 additions & 0 deletions models/gan_networks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import torch
import torch.nn as nn
import functools
from torch.optim import lr_scheduler
Expand Down Expand Up @@ -41,6 +42,9 @@
UNet as UNet_mha,
UViT as UViT,
)
from .modules.unet_generator_attn.unet_discriminator_attn import (
UNet as UNet_discriminator_mha,
)


def define_G(
Expand Down Expand Up @@ -244,6 +248,8 @@ def define_G(
def define_D(
D_netDs,
model_input_nc,
model_output_nc,
D_num_downs,
D_ndf,
D_n_layers,
D_norm,
Expand All @@ -266,14 +272,21 @@ def define_D(
f_s_semantic_nclasses,
model_depth_network,
train_feat_wavelet,
G_unet_mha_num_head_channels,
G_unet_mha_res_blocks,
G_unet_mha_channel_mults,
G_unet_mha_norm_layer,
G_unet_mha_group_norm_size,
**unused_options
):

"""Create a discriminator

Parameters:
model_input_nc (int) -- the number of channels in input images
model_output_nc (int) -- the number of channels in output images
D_ndf (int) -- the number of filters in the first conv layer
D_num_downs (int) -- the number of downsamplings in UNet. For example, if |D_num_downs| == 7, an image of size 128x128 will become of size 1x1 at the bottleneck
netD (str) -- the architecture's name: basic | n_layers | pixel
D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
D_norm (str) -- the type of normalization layers used in the network.
Expand Down Expand Up @@ -432,6 +445,25 @@ def define_D(
)
return_nets[netD] = init_net(net, model_init_type, model_init_gain)

elif netD == "unet_discriminator_mha":
net = UNet_discriminator_mha(
image_size=data_crop_size,
in_channel=model_input_nc,
inner_channel=D_ndf,
cond_embed_dim=D_ndf * 4,
out_channel=model_output_nc,
res_blocks=G_unet_mha_res_blocks,
attn_res=[16],
channel_mults=G_unet_mha_channel_mults, # e.g. (1, 2, 4, 8)
num_head_channels=G_unet_mha_num_head_channels,
tanh=True,
n_timestep_train=0, # unused
n_timestep_test=0, # unused
norm=G_unet_mha_norm_layer,
group_norm_size=G_unet_mha_group_norm_size,
)
return_nets[netD] = net # init_net(net, model_init_type, model_init_gain)

else:
raise NotImplementedError(
"Discriminator model name [%s] is not recognized" % netD
Expand Down
4 changes: 3 additions & 1 deletion models/modules/discriminators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import functools

import torch
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed I believe.

Copy link
Collaborator Author

@wr0124 wr0124 Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'll check on it and delete it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remove it, my code does not work.

import numpy as np
from torch import nn
from torch.nn import functional as F

from .utils import spectral_norm, normal_init

# Turns on autograd anomaly detection as a side effect of importing this
# module; this call is why the `import torch` line above is needed here.
# NOTE(review): PyTorch documents anomaly detection as a debugging aid that
# slows execution -- confirm this is meant to stay enabled unconditionally.
torch.autograd.set_detect_anomaly(True)


class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
Expand Down
135 changes: 133 additions & 2 deletions models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import torch.nn.functional as F
import random
import math

# import numpy as np
import numpy as np
from util.cutmix import CutMix


class GANLoss(nn.Module):
Expand Down Expand Up @@ -394,6 +394,137 @@ def compute_loss_G(self, netD, real, fake):
return loss_G


class DualDiscriminatorGANLoss(DiscriminatorLoss):
    """GAN loss for a discriminator that returns two predictions.

    ``netD`` is called on an image batch and must return a pair
    ``(pixel_prediction, bottleneck_prediction)``; both are scored with the
    same ``GANLoss`` criterion and summed. When ``train_use_cutmix`` is set,
    a CutMix consistency term on the pixel prediction is added to the
    discriminator loss (reference: https://arxiv.org/abs/2002.12655).
    """

    def __init__(
        self,
        netD,
        device,
        dataaug_APA_p,
        dataaug_APA_target,
        train_batch_size,
        dataaug_APA_nimg,
        dataaug_APA_every,
        dataaug_D_label_smooth,
        train_gan_mode,
        dataaug_APA,
        dataaug_D_diffusion,
        dataaug_D_diffusion_every,
        train_use_cutmix,
    ):
        """Build the loss.

        Parameters:
            dataaug_D_label_smooth (bool) -- if truthy, the real target label is 0.9 instead of 1.0
            train_gan_mode (str)          -- GAN objective name forwarded to GANLoss
            train_use_cutmix (bool)       -- enable the CutMix consistency term in compute_loss_D
        All remaining arguments are forwarded unchanged to the parent class.
        """
        super().__init__(
            netD,
            device,
            dataaug_APA_p,
            dataaug_APA_target,
            train_batch_size,
            dataaug_APA_nimg,
            dataaug_APA_every,
            dataaug_APA,
            dataaug_D_diffusion,
            dataaug_D_diffusion_every,
        )
        target_real_label = 0.9 if dataaug_D_label_smooth else 1.0

        self.gan_mode = train_gan_mode
        self.train_use_cutmix = train_use_cutmix

        self.criterionGAN = GANLoss(
            self.gan_mode, target_real_label=target_real_label
        ).to(self.device)

    @staticmethod
    def _cutmix_real_fake_pairwise(real, fake):
        """Mix each real/fake pair with its own CutMix mask.

        Returns ``(mixed_images, masks)`` where ``masks`` has a singleton
        channel axis inserted at dim 1 and ``mixed = masks * fake + (1 - masks) * real``,
        i.e. mask value 1 selects the fake image.
        """
        if real.size() != fake.size():
            raise ValueError("Real and fake images should have the same dimensions.")
        # One independent mask per sample in the batch.
        # NOTE(review): assumes CutMix(size) yields a 2-D (H, W) mask tensor -- confirm.
        masks = torch.stack(
            [CutMix(real.size(2)).to(real.device) for _ in range(real.size(0))]
        )
        masks = masks.unsqueeze(1)
        mixed_images = masks * fake + (1 - masks) * real
        return mixed_images, masks

    def compute_loss_D(self, netD, real, fake, fake_2):
        """Calculate GAN loss for the discriminator
        Parameters:
            netD (network)      -- the discriminator D, returning (pixel, bottleneck) predictions
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator
            fake_2              -- forwarded to the parent class only
        Return the discriminator loss.
        """
        # The parent call is expected to populate self.real / self.fake
        # (presumably after augmentation -- verify against DiscriminatorLoss).
        super().compute_loss_D(netD, real, fake, fake_2)

        lambda_loss = 0.5

        # Real
        pred_real_pixel, pred_real_bottleneck = netD(self.real)
        loss_real_pixel = self.criterionGAN(pred_real_pixel, True)
        loss_real_bottleneck = self.criterionGAN(pred_real_bottleneck, True)
        self.loss_D_real = loss_real_pixel + loss_real_bottleneck

        # Fake (detached so no gradient reaches the generator)
        fake_input = self.fake.detach()
        pred_fake_pixel, pred_fake_bottleneck = netD(fake_input)

        loss_cutmix_pixel = None
        if self.train_use_cutmix:
            cutmix_img, label_masks = self._cutmix_real_fake_pairwise(
                self.real, fake_input
            )
            pred_cutmix_pixel, _ = netD(cutmix_img.detach())

            # Consistency target: mix the per-pixel predictions with the SAME
            # mask convention as the images (mask == 1 -> fake). The mask keeps
            # its singleton channel axis and broadcasts over the prediction
            # channels, so no channel count is hard-coded.
            consistent_pred_pixel = (
                label_masks * pred_fake_pixel + (1 - label_masks) * pred_real_pixel
            )
            loss_cutmix_pixel = torch.norm(
                pred_cutmix_pixel - consistent_pred_pixel, p=2
            ).pow(2)

        loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False)
        loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False)
        loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel

        loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss
        if loss_cutmix_pixel is not None:
            lambda_cutmix = 0.5
            loss_D = loss_D + loss_cutmix_pixel * lambda_cutmix

        return loss_D

    def compute_loss_G(self, netD, real, fake):
        """Calculate the GAN loss for the generator.

        Scores the discriminator's pixel and bottleneck predictions on
        ``self.fake`` against the "real" target and returns their sum.
        """
        super().compute_loss_G(netD, real, fake)
        pred_fake_pixel, pred_fake_bottleneck = netD(self.fake)

        loss_G_pixel = self.criterionGAN(pred_fake_pixel, True, relu=False)
        loss_G_bottleneck = self.criterionGAN(pred_fake_bottleneck, True, relu=False)

        loss_G = loss_G_pixel + loss_G_bottleneck
        return loss_G


class MultiScaleDiffusionLoss(nn.Module):
"""
Multiscale diffusion loss such as in 2301.11093.
Expand Down
Loading
Loading