diff --git a/models/base_gan_model.py b/models/base_gan_model.py index ce7f775b8..9a4760cc0 100644 --- a/models/base_gan_model.py +++ b/models/base_gan_model.py @@ -208,6 +208,7 @@ def forward_GAN(self): if self.use_temporal: self.compute_temporal_fake(objective_domain="B") + if hasattr(self, "netG_B"): self.compute_temporal_fake(objective_domain="A") @@ -401,7 +402,6 @@ def compute_D_loss(self): loss_name, loss_value, ) - self.loss_D_tot += loss_value def compute_G_loss_GAN_generic( @@ -479,7 +479,6 @@ def compute_G_loss_GAN(self): loss_name, loss_value, ) - self.loss_G_tot += loss_value if self.opt.train_temporal_criterion: @@ -531,6 +530,9 @@ def set_discriminators_info(self): else: train_gan_mode = self.opt.train_gan_mode + if "une_discriminator_mha" in discriminator_name: + train_use_cutmix = self.opt.train_use_cutmix + if "projected" in discriminator_name: dataaug_D_diffusion = self.opt.dataaug_D_diffusion dataaug_D_diffusion_every = self.opt.dataaug_D_diffusion_every @@ -562,6 +564,26 @@ def set_discriminators_info(self): real_name = "temporal_real" compute_every = self.opt.D_temporal_every + elif "unet_discriminator_mha" in discriminator_name: + loss_calculator = loss.DualDiscriminatorGANLoss( + netD=getattr(self, "net" + discriminator_name), + device=self.device, + dataaug_APA_p=self.opt.dataaug_APA_p, + dataaug_APA_target=self.opt.dataaug_APA_target, + train_batch_size=self.opt.train_batch_size, + dataaug_APA_nimg=self.opt.dataaug_APA_nimg, + dataaug_APA_every=self.opt.dataaug_APA_every, + dataaug_D_label_smooth=self.opt.dataaug_D_label_smooth, + train_gan_mode=train_gan_mode, + dataaug_APA=self.opt.dataaug_APA, + dataaug_D_diffusion=dataaug_D_diffusion, + dataaug_D_diffusion_every=dataaug_D_diffusion_every, + train_use_cutmix=self.opt.train_use_cutmix, + ) + fake_name = None + real_name = None + compute_every = 1 + else: fake_name = None real_name = None @@ -598,12 +620,6 @@ def set_discriminators_info(self): 
dataaug_D_diffusion_every=dataaug_D_diffusion_every, ) - setattr( - self, - loss_calculator_name, - loss_calculator, - ) - if "depth" in discriminator_name: fake_name = "fake_depth" real_name = "real_depth" @@ -614,6 +630,12 @@ def set_discriminators_info(self): fake_name = "fake_sam" real_name = "real_sam" + setattr( + self, + loss_calculator_name, + loss_calculator, + ) + self.objects_to_update.append(getattr(self, loss_calculator_name)) self.discriminators.append( diff --git a/models/gan_networks.py b/models/gan_networks.py index af8266392..a36619681 100644 --- a/models/gan_networks.py +++ b/models/gan_networks.py @@ -1,4 +1,5 @@ import os +import torch import torch.nn as nn import functools from torch.optim import lr_scheduler @@ -41,6 +42,9 @@ UNet as UNet_mha, UViT as UViT, ) +from .modules.unet_generator_attn.unet_discriminator_attn import ( + UNet as UNet_discriminator_mha, +) def define_G( @@ -244,6 +248,8 @@ def define_G( def define_D( D_netDs, model_input_nc, + model_output_nc, + D_num_downs, D_ndf, D_n_layers, D_norm, @@ -266,6 +272,11 @@ def define_D( f_s_semantic_nclasses, model_depth_network, train_feat_wavelet, + G_unet_mha_num_head_channels, + G_unet_mha_res_blocks, + G_unet_mha_channel_mults, + G_unet_mha_norm_layer, + G_unet_mha_group_norm_size, **unused_options ): @@ -273,7 +284,9 @@ def define_D( Parameters: model_input_nc (int) -- the number of channels in input images + model_output_nc (int) -- the number of channels in output images D_ndf (int) -- the number of filters in the first conv layer + D_num_downs (int) -- the number of downsamplings in UNet. For example, if D_num_downs == 7, an image of size 128x128 becomes 1x1 at the bottleneck netD (str) -- the architecture's name: basic | n_layers | pixel D_n_layers (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' D_norm (str) -- the type of normalization layers used in the network.
@@ -432,6 +445,25 @@ def define_D( ) return_nets[netD] = init_net(net, model_init_type, model_init_gain) + elif netD == "unet_discriminator_mha": + net = UNet_discriminator_mha( + image_size=data_crop_size, + in_channel=model_input_nc, + inner_channel=D_ndf, + cond_embed_dim=D_ndf * 4, + out_channel=model_output_nc, + res_blocks=G_unet_mha_res_blocks, + attn_res=[16], + channel_mults=G_unet_mha_channel_mults, # e.g. (1, 2, 4, 8) + num_head_channels=G_unet_mha_num_head_channels, + tanh=True, + n_timestep_train=0, # unused + n_timestep_test=0, # unused + norm=G_unet_mha_norm_layer, + group_norm_size=G_unet_mha_group_norm_size, + ) + return_nets[netD] = net # init_net(net, model_init_type, model_init_gain) + else: raise NotImplementedError( "Discriminator model name [%s] is not recognized" % netD diff --git a/models/modules/discriminators.py b/models/modules/discriminators.py index 5662df779..3eca009b3 100644 --- a/models/modules/discriminators.py +++ b/models/modules/discriminators.py @@ -1,11 +1,13 @@ import functools - +import torch import numpy as np from torch import nn from torch.nn import functional as F from .utils import spectral_norm, normal_init +torch.autograd.set_detect_anomaly(True) + class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator""" diff --git a/models/modules/loss.py b/models/modules/loss.py index 13d810f36..08fa16b88 100644 --- a/models/modules/loss.py +++ b/models/modules/loss.py @@ -4,8 +4,8 @@ import torch.nn.functional as F import random import math - -# import numpy as np +import numpy as np +from util.cutmix import CutMix class GANLoss(nn.Module): @@ -394,6 +394,137 @@ def compute_loss_G(self, netD, real, fake): return loss_G +class DualDiscriminatorGANLoss(DiscriminatorLoss): + """ + unet loss integrating encoder-decoder losses with optional CutMix augmentation from reference https://arxiv.org/abs/2002.12655 + """ + + def __init__( + self, + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + 
dataaug_APA_nimg, + dataaug_APA_every, + dataaug_D_label_smooth, + train_gan_mode, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + train_use_cutmix, + ): + super().__init__( + netD, + device, + dataaug_APA_p, + dataaug_APA_target, + train_batch_size, + dataaug_APA_nimg, + dataaug_APA_every, + dataaug_APA, + dataaug_D_diffusion, + dataaug_D_diffusion_every, + ) + if dataaug_D_label_smooth: + target_real_label = 0.9 + else: + target_real_label = 1.0 + + self.gan_mode = train_gan_mode + self.train_use_cutmix = train_use_cutmix + + self.criterionGAN = GANLoss( + self.gan_mode, target_real_label=target_real_label + ).to(self.device) + + def compute_loss_D(self, netD, real, fake, fake_2): + """Calculate GAN loss for the discriminator + Parameters: + netD (network) -- the discriminator D + real (tensor array) -- real images + fake (tensor array) -- images generated by a generator + Return the discriminator loss. + This method does not call loss_D.backward(); it only returns the loss. + """ + super().compute_loss_D(netD, real, fake, fake_2) + + # Real + pred_real_pixel, pred_real_bottleneck = netD(self.real) + + loss_pred_real_pixel = self.criterionGAN(pred_real_pixel, True) + loss_pred_real_bottleneck = self.criterionGAN(pred_real_bottleneck, True) + self.loss_D_real = loss_pred_real_pixel + loss_pred_real_bottleneck + + # Fake + + lambda_loss = 0.5 + + def cutmix_real_fake_pairwise(real, fake): + assert ( + self.real.size() == self.fake.size() + ), "Real and fake images should have the same dimensions."
+ masks = torch.stack( + [ + CutMix(self.real.size(2)).to(self.real.device) + for _ in range(self.real.size(0)) + ] + ) + masks = masks.unsqueeze(1) + mixed_images = masks * fake + (1 - masks) * real + return mixed_images, masks + + if self.train_use_cutmix: + lambda_cutmix = 0.5 + fake_input = self.fake.detach() + pred_fake_pixel, pred_fake_bottleneck = netD(fake_input) + + cutmix_img, label_masks = cutmix_real_fake_pairwise(self.real, fake_input) + fake_cutmix_input = cutmix_img.detach() + pred_cutmix_fake_pixel, pred_cutmix_fake_bottleneck = netD( + fake_cutmix_input + ) + cutmix_pixel_label = label_masks.repeat(1, 3, 1, 1) + + consistent_pred_pixel = torch.mul( + pred_fake_pixel, 1 - cutmix_pixel_label + ) + torch.mul(pred_real_pixel, cutmix_pixel_label) + + loss_cutmix_pixel = torch.norm( + pred_cutmix_fake_pixel - consistent_pred_pixel, p=2 + ).pow(2) + + loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False) + loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False) + loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel + loss_D = ( + self.loss_D_real + loss_D_fake + ) * lambda_loss + loss_cutmix_pixel * lambda_cutmix + + else: + fake_input = self.fake.detach() + pred_fake_pixel, pred_fake_bottleneck = netD(fake_input) + + loss_D_fake_pixel = self.criterionGAN(pred_fake_pixel, False) + loss_D_fake_bottleneck = self.criterionGAN(pred_fake_bottleneck, False) + loss_D_fake = loss_D_fake_bottleneck + loss_D_fake_pixel + loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss + + return loss_D + + def compute_loss_G(self, netD, real, fake): + + super().compute_loss_G(netD, real, fake) + pred_fake_pixel, pred_fake_bottleneck = netD(self.fake) + + loss_G_pixel = self.criterionGAN(pred_fake_pixel, True, relu=False) + loss_G_bottleneck = self.criterionGAN(pred_fake_bottleneck, True, relu=False) + + loss_D_fake = loss_G_pixel + loss_G_bottleneck + return loss_D_fake + + class MultiScaleDiffusionLoss(nn.Module): """ Multiscale diffusion loss 
such as in 2301.11093. diff --git a/models/modules/unet_generator_attn/unet_discriminator_attn.py b/models/modules/unet_generator_attn/unet_discriminator_attn.py new file mode 100644 index 000000000..aedc3200c --- /dev/null +++ b/models/modules/unet_generator_attn/unet_discriminator_attn.py @@ -0,0 +1,713 @@ +from abc import abstractmethod +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + + +from .unet_attn_utils import ( + checkpoint, + zero_module, + normalization, + normalization1d, + count_flops_attn, +) + + +class EmbedBlock(nn.Module): + """ + Any module where forward() takes embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` embeddings. + """ + + +class EmbedSequential(nn.Sequential, EmbedBlock): + """ + A sequential module that passes embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, EmbedBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
+ + """ + + def __init__( + self, channels, use_conv, out_channel=None, efficient=False, freq_space=False + ): + super().__init__() + self.channels = channels + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.freq_space = freq_space + + if freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channel, 3, padding=1) + self.efficient = efficient + + def forward(self, x): + if self.freq_space: + x = self.iwt(x) + + assert x.shape[1] == self.channels + if not self.efficient: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + if self.efficient: # if efficient, we do the interpolation after the conv + x = F.interpolate(x, scale_factor=2, mode="nearest") + + if self.freq_space: + x = self.dwt(x) + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
+ """ + + def __init__(self, channels, use_conv, out_channel=None, freq_space=False): + super().__init__() + self.channels = channels + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + + stride = 2 + if use_conv: + self.op = nn.Conv2d( + self.channels, self.out_channel, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channel + self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) + + def forward(self, x): + if self.freq_space: + x = self.iwt(x) + + assert x.shape[1] == self.channels + opx = self.op(x) + + if self.freq_space: + opx = self.dwt(opx) + + return opx + + +class ResBlock(EmbedBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of embedding channels. + :param dropout: the rate of dropout. + :param out_channel: if specified, the number of out channels. + :param use_conv: if True and out_channel is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + norm, + out_channel=None, + use_conv=False, + use_scale_shift_norm=False, + use_checkpoint=False, + up=False, + down=False, + efficient=False, + freq_space=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.up = up + self.efficient = efficient + self.freq_space = freq_space + self.updown = up or down + + self.in_layers = nn.Sequential( + normalization(self.channels, norm), + torch.nn.SiLU(), + nn.Conv2d(self.channels, self.out_channel, 3, padding=1), + ) + + if up: + self.h_upd = Upsample(channels, False, freq_space=self.freq_space) + self.x_upd = Upsample(channels, False, freq_space=self.freq_space) + elif down: + self.h_upd = Downsample(channels, False, freq_space=self.freq_space) + self.x_upd = Downsample(channels, False, freq_space=self.freq_space) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + torch.nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channel if use_scale_shift_norm else self.out_channel, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channel, norm), + torch.nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv2d(self.out_channel, self.out_channel, 3, padding=1)), + ) + + if self.out_channel == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv2d(channels, self.out_channel, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channel, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + + h = in_rest(x) + + if self.efficient and self.up: + h = in_conv(h) + h = self.h_upd(h) + x = self.x_upd(x) + else: + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out.unsqueeze(-1) + # emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + skipw = 1.0 + if self.efficient: + skipw = 1.0 / math.sqrt(2) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + use_transformer=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.use_transformer = use_transformer + self.norm = normalization1d(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + def _forward(self, x): + b, c, *spatial = x.shape + if self.use_transformer: + x = x.reshape(b, -1, c) + else: + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) + ) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNet(nn.Module): + """ + The full UNet model with attention and embedding. + :param in_channel: channels in the input Tensor, for image colorization : Y_channels + X_channels . + :param inner_channel: base channel count for the model. + :param out_channel: channels in the output Tensor. + :param res_blocks: number of residual blocks per downsample. 
+ :param attn_res: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mults: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_head_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency.
+ """ + + def __init__( + self, + image_size, + in_channel, + inner_channel, + out_channel, + res_blocks, + attn_res, + tanh, + n_timestep_train, + n_timestep_test, + norm, + group_norm_size, + cond_embed_dim, + dropout=0, + channel_mults=(1, 2, 4, 8), + conv_resample=True, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=True, + resblock_updown=True, + use_new_attention_order=False, + efficient=False, + freq_space=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channel = in_channel + self.inner_channel = inner_channel + self.out_channel = out_channel + self.res_blocks = res_blocks + self.attn_res = attn_res + self.dropout = dropout + self.channel_mults = channel_mults + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + in_channel *= 4 + out_channel *= 4 + + if norm == "groupnorm": + norm = norm + str(group_norm_size) + + self.cond_embed_dim = cond_embed_dim + + ch = input_ch = int(channel_mults[0] * self.inner_channel) + self.input_blocks = nn.ModuleList( + [EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mults): + for _ in range(res_blocks[level]): + layers = [ + ResBlock( + ch, + self.cond_embed_dim, + dropout, + out_channel=int(mult * self.inner_channel), + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + 
freq_space=self.freq_space, + ) + ] + ch = int(mult * self.inner_channel) + if ds in attn_res: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(EmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mults) - 1: + out_ch = ch + self.input_blocks.append( + EmbedSequential( + ResBlock( + ch, + self.cond_embed_dim, + dropout, + out_channel=out_ch, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = EmbedSequential( + ResBlock( + ch, + self.cond_embed_dim, + dropout, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ), + # AttentionBlock( + # ch, + # use_checkpoint=use_checkpoint, + # num_heads=num_heads, + # num_head_channels=num_head_channels, + # use_new_attention_order=use_new_attention_order, + # ), + ResBlock( + ch, + self.cond_embed_dim, + dropout, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ), + ) + self._feature_size += ch + self.bottleneck_conv = nn.Conv2d( + self.inner_channel * channel_mults[-1], + 2, + kernel_size=2, + stride=1, + padding=0, + bias=True, + ) + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mults))[::-1]: + for i in range(res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + self.cond_embed_dim, + dropout, + 
out_channel=int(self.inner_channel * mult), + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ) + ] + ch = int(self.inner_channel * mult) + if ds in attn_res: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + if level and i == res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + self.cond_embed_dim, + dropout, + out_channel=out_ch, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ) + if resblock_updown + else Upsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) + ) + ds //= 2 + self.output_blocks.append(EmbedSequential(*layers)) + self._feature_size += ch + + if tanh: + self.out = nn.Sequential( + normalization(ch, norm), + zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), + nn.Sigmoid(), + ) + else: + self.out = nn.Sequential( + normalization(ch, norm), + torch.nn.SiLU(), + zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), + ) + + self.beta_schedule = { + "train": { + "schedule": "linear", + "n_timestep": n_timestep_train, + "linear_start": 1e-6, + "linear_end": 0.01, + }, + "test": { + "schedule": "linear", + "n_timestep": n_timestep_test, + "linear_start": 1e-4, + "linear_end": 0.09, + }, + } + + def compute_feats(self, input, embed_gammas): + if embed_gammas is None: + # Only for GAN + b = (input.shape[0], self.cond_embed_dim) + embed_gammas = torch.ones(b).to(input.device) + + emb = embed_gammas + + hs = [] + + h = input.type(torch.float32) + + if self.freq_space: + h = self.dwt(h) + + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + + bottleneck_conv = self.bottleneck_conv(h) + + 
# outh_encoder = nn.Sigmoid()(bottleneck_conv) #nn.Tanh()(bottleneck_conv) + + outs, feats = h, hs + + return outs, feats, emb, bottleneck_conv # outh_encoder + + def forward(self, input, embed_gammas=None): + h, hs, emb, outh_encoder = self.compute_feats(input, embed_gammas=embed_gammas) + + for i, module in enumerate(self.output_blocks): + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb) + h = h.type(input.dtype) + outh = self.out(h) + + if self.freq_space: + outh = self.iwt(outh) + + return outh, outh_encoder + + def get_feats(self, input, extract_layer_ids): + _, hs, _ = self.compute_feats(input, embed_gammas=None) + feats = [] + + for i, feat in enumerate(hs): + if i in extract_layer_ids: + feats.append(feat) + + return feats + + def extract(self, a, t, x_shape=(1, 1, 1, 1)): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) diff --git a/options/base_options.py b/options/base_options.py index a72a20070..e143b2179 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -379,6 +379,20 @@ def initialize(self, parser): default=64, help="# of discrim filters in the first conv layer", ) + parser.add_argument( + "--D_ngf", + type=int, + default=64, + help="#*8 of discrim filters in the last conv layer", + ) + + parser.add_argument( + "--D_num_downs", + type=int, + default=7, + help="# of downsampling", + ) + parser.add_argument( "--D_netDs", type=str, @@ -393,6 +407,8 @@ def initialize(self, parser): "depth", "mask", "sam", + "unet", + "unet_discriminator_mha", ] + list(TORCH_MODEL_CLASSES.keys()), help="specify discriminator architecture, another option, --D_n_layers allows you to specify the layers in the n_layers discriminator. NB: duplicated arguments are ignored. 
Values: basic, n_layers, pixel, projected_d, temporal, vision_aided, depth, mask, sam", diff --git a/options/train_options.py b/options/train_options.py index 8ae849dfa..cfc6076e5 100644 --- a/options/train_options.py +++ b/options/train_options.py @@ -302,6 +302,14 @@ def initialize(self, parser): default=1, help="backward will be apllied each iter_size iterations, it simulate a greater batch size : its value is batch_size*iter_size", ) + + parser.add_argument( + "--train_use_cutmix", + type=bool, + default=False, + help="add cutmix augmentation", + ) + parser.add_argument("--train_use_contrastive_loss_D", action="store_true") # frequency space training diff --git a/tests/test_run_nosemantic.py b/tests/test_run_nosemantic.py index e15c1edc9..3f6cbae89 100644 --- a/tests/test_run_nosemantic.py +++ b/tests/test_run_nosemantic.py @@ -32,7 +32,11 @@ "cycle_gan", ] -D_netDs = [["projected_d", "basic"], ["projected_d", "basic", "depth"]] +D_netDs = [ + ["projected_d", "basic"], + ["projected_d", "basic", "depth"], + ["projected_d", "basic", "unet"], +] train_feat_wavelet = [False, True] diff --git a/util/cutmix.py b/util/cutmix.py new file mode 100644 index 000000000..d2c961c29 --- /dev/null +++ b/util/cutmix.py @@ -0,0 +1,30 @@ +import torch +import numpy as np + +# Define the CutMix function and its supporting function +# CutMix from https://github.com/boschresearch/unetgan/blob/master/mixup.py + + +def random_boundingbox(size, lam): + width, height = size, size + r = np.sqrt(1.0 - lam) + w = int(width * r) # Modified this line + h = int(height * r) # And this line + x = np.random.randint(width) + y = np.random.randint(height) + x1 = np.clip(x - w // 2, 0, width) + y1 = np.clip(y - h // 2, 0, height) + x2 = np.clip(x + w // 2, 0, width) + y2 = np.clip(y + h // 2, 0, height) + return x1, y1, x2, y2 + + +def CutMix(imsize): + lam = np.random.beta(1, 1) + x1, y1, x2, y2 = random_boundingbox(imsize, lam) + lam = 1 - ((x2 - x1) * (y2 - y1) / (imsize * imsize)) + mask = 
torch.ones((imsize, imsize)) + mask[x1:x2, y1:y2] = 0 + if torch.rand(1) > 0.5: + mask = 1 - mask + return mask diff --git a/util/parser.py b/util/parser.py index 43295ac72..10d8ffaec 100644 --- a/util/parser.py +++ b/util/parser.py @@ -31,7 +31,7 @@ def get_opt(main_opt, remaining_args): for name in override_options_names: train_json[name] = override_options_json[name] - opt = TrainOptions().parse_json(train_json) + opt = TrainOptions().parse_json(train_json, save_config=True) print("%s config file loaded" % main_opt.config_json) else: