diff --git a/models/pggan_generator.py b/models/pggan_generator.py index b22b21a..3c8390b 100644 --- a/models/pggan_generator.py +++ b/models/pggan_generator.py @@ -25,14 +25,17 @@ def __init__(self, model_name, logger=None): assert self.gan_type == 'pggan' def build(self): - self.check_attr('final_tanh') + self.check_attr('fused_scale') self.model = PGGANGeneratorModel(resolution=self.resolution, - final_tanh=self.final_tanh) + fused_scale=self.fused_scale, + output_channels=self.output_channels) def load(self): self.logger.info(f'Loading pytorch model from `{self.model_path}`.') self.model.load_state_dict(torch.load(self.model_path)) self.logger.info(f'Successfully loaded!') + self.lod = self.model.lod.to(self.cpu_device).tolist() + self.logger.info(f' `lod` of the loaded model is {self.lod}.') def convert_tf_model(self, test_num=10): import sys @@ -52,12 +55,16 @@ def convert_tf_model(self, test_num=10): state_dict = self.model.state_dict() for pth_var_name, tf_var_name in self.model.pth_to_tf_var_mapping.items(): if tf_var_name not in tf_vars: + self.logger.debug(f'Variable `{tf_var_name}` does not exist in ' + f'tensorflow model.') continue self.logger.debug(f' Converting `{tf_var_name}` to `{pth_var_name}`.') - var = torch.from_numpy(tf_vars[tf_var_name]) + var = torch.from_numpy(np.array(tf_vars[tf_var_name])) if 'weight' in pth_var_name: - if 'layer1.conv' in pth_var_name: + if 'layer0.conv' in pth_var_name: var = var.view(var.shape[0], -1, 4, 4).permute(1, 0, 2, 3).flip(2, 3) + elif 'Conv0_up' in tf_var_name: + var = var.permute(0, 1, 3, 2) else: var = var.permute(3, 2, 0, 1) state_dict[pth_var_name] = var diff --git a/models/pggan_generator_model.py b/models/pggan_generator_model.py index 26384ca..146f19f 100644 --- a/models/pggan_generator_model.py +++ b/models/pggan_generator_model.py @@ -9,11 +9,11 @@ https://arxiv.org/pdf/1710.10196.pdf """ -from collections import OrderedDict import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F __all__ = ['PGGANGeneratorModel'] @@ -32,99 +32,152 @@ # Variable mapping from pytorch model to official tensorflow model. 
_PGGAN_PTH_VARS_TO_TF_VARS = { - 'layer1.conv.weight': '4x4/Dense/weight', # [512, 512, 4, 4] - 'layer1.wscale.bias': '4x4/Dense/bias', # [512] - 'layer2.conv.weight': '4x4/Conv/weight', # [512, 512, 3, 3] - 'layer2.wscale.bias': '4x4/Conv/bias', # [512] - 'layer3.conv.weight': '8x8/Conv0/weight', # [512, 512, 3, 3] - 'layer3.wscale.bias': '8x8/Conv0/bias', # [512] - 'layer4.conv.weight': '8x8/Conv1/weight', # [512, 512, 3, 3] - 'layer4.wscale.bias': '8x8/Conv1/bias', # [512] - 'layer5.conv.weight': '16x16/Conv0/weight', # [512, 512, 3, 3] - 'layer5.wscale.bias': '16x16/Conv0/bias', # [512] - 'layer6.conv.weight': '16x16/Conv1/weight', # [512, 512, 3, 3] - 'layer6.wscale.bias': '16x16/Conv1/bias', # [512] - 'layer7.conv.weight': '32x32/Conv0/weight', # [512, 512, 3, 3] - 'layer7.wscale.bias': '32x32/Conv0/bias', # [512] - 'layer8.conv.weight': '32x32/Conv1/weight', # [512, 512, 3, 3] - 'layer8.wscale.bias': '32x32/Conv1/bias', # [512] - 'layer9.conv.weight': '64x64/Conv0/weight', # [256, 512, 3, 3] - 'layer9.wscale.bias': '64x64/Conv0/bias', # [256] - 'layer10.conv.weight': '64x64/Conv1/weight', # [256, 256, 3, 3] - 'layer10.wscale.bias': '64x64/Conv1/bias', # [256] - 'layer11.conv.weight': '128x128/Conv0/weight', # [128, 256, 3, 3] - 'layer11.wscale.bias': '128x128/Conv0/bias', # [128] - 'layer12.conv.weight': '128x128/Conv1/weight', # [128, 128, 3, 3] - 'layer12.wscale.bias': '128x128/Conv1/bias', # [128] - 'layer13.conv.weight': '256x256/Conv0/weight', # [64, 128, 3, 3] - 'layer13.wscale.bias': '256x256/Conv0/bias', # [64] - 'layer14.conv.weight': '256x256/Conv1/weight', # [64, 64, 3, 3] - 'layer14.wscale.bias': '256x256/Conv1/bias', # [64] - 'layer15.conv.weight': '512x512/Conv0/weight', # [32, 64, 3, 3] - 'layer15.wscale.bias': '512x512/Conv0/bias', # [32] - 'layer16.conv.weight': '512x512/Conv1/weight', # [32, 32, 3, 3] - 'layer16.wscale.bias': '512x512/Conv1/bias', # [32] - 'layer17.conv.weight': '1024x1024/Conv0/weight', # [16, 32, 3, 3] - 'layer17.wscale.bias': '1024x1024/Conv0/bias', # [16] - 'layer18.conv.weight': '1024x1024/Conv1/weight', # [16, 16, 3, 3] - 'layer18.wscale.bias': '1024x1024/Conv1/bias', # [16] - 'output_1024x1024.conv.weight': 'ToRGB_lod0/weight', # [3, 16, 1, 1] - 'output_1024x1024.wscale.bias': 'ToRGB_lod0/bias', # [3] + 'lod': 'lod', # [] + 'layer0.conv.weight': '4x4/Dense/weight', # [512, 512, 4, 4] + 'layer0.wscale.bias': '4x4/Dense/bias', # [512] + 'layer1.conv.weight': '4x4/Conv/weight', # [512, 512, 3, 3] + 'layer1.wscale.bias': '4x4/Conv/bias', # [512] + 'layer2.conv.weight': '8x8/Conv0/weight', # [512, 512, 3, 3] + 'layer2.wscale.bias': '8x8/Conv0/bias', # [512] + 'layer3.conv.weight': '8x8/Conv1/weight', # [512, 512, 3, 3] + 'layer3.wscale.bias': '8x8/Conv1/bias', # [512] + 'layer4.conv.weight': '16x16/Conv0/weight', # [512, 512, 3, 3] + 'layer4.wscale.bias': '16x16/Conv0/bias', # [512] + 'layer5.conv.weight': '16x16/Conv1/weight', # [512, 512, 3, 3] + 'layer5.wscale.bias': '16x16/Conv1/bias', # [512] + 'layer6.conv.weight': '32x32/Conv0/weight', # [512, 512, 3, 3] + 'layer6.wscale.bias': '32x32/Conv0/bias', # [512] + 'layer7.conv.weight': '32x32/Conv1/weight', # [512, 512, 3, 3] + 'layer7.wscale.bias': '32x32/Conv1/bias', # [512] + 'layer8.conv.weight': '64x64/Conv0/weight', # [256, 512, 3, 3] + 'layer8.wscale.bias': '64x64/Conv0/bias', # [256] + 'layer9.conv.weight': '64x64/Conv1/weight', # [256, 256, 3, 3] + 'layer9.wscale.bias': '64x64/Conv1/bias', # [256] + 'layer10.conv.weight': '128x128/Conv0/weight', # [128, 256, 3, 3] + 
'layer10.wscale.bias': '128x128/Conv0/bias', # [128] + 'layer11.conv.weight': '128x128/Conv1/weight', # [128, 128, 3, 3] + 'layer11.wscale.bias': '128x128/Conv1/bias', # [128] + 'layer12.conv.weight': '256x256/Conv0/weight', # [64, 128, 3, 3] + 'layer12.wscale.bias': '256x256/Conv0/bias', # [64] + 'layer13.conv.weight': '256x256/Conv1/weight', # [64, 64, 3, 3] + 'layer13.wscale.bias': '256x256/Conv1/bias', # [64] + 'layer14.conv.weight': '512x512/Conv0/weight', # [32, 64, 3, 3] + 'layer14.wscale.bias': '512x512/Conv0/bias', # [32] + 'layer15.conv.weight': '512x512/Conv1/weight', # [32, 32, 3, 3] + 'layer15.wscale.bias': '512x512/Conv1/bias', # [32] + 'layer16.conv.weight': '1024x1024/Conv0/weight', # [16, 32, 3, 3] + 'layer16.wscale.bias': '1024x1024/Conv0/bias', # [16] + 'layer17.conv.weight': '1024x1024/Conv1/weight', # [16, 16, 3, 3] + 'layer17.wscale.bias': '1024x1024/Conv1/bias', # [16] + 'output0.conv.weight': 'ToRGB_lod8/weight', # [3, 512, 1, 1] + 'output0.wscale.bias': 'ToRGB_lod8/bias', # [3] + 'output1.conv.weight': 'ToRGB_lod7/weight', # [3, 512, 1, 1] + 'output1.wscale.bias': 'ToRGB_lod7/bias', # [3] + 'output2.conv.weight': 'ToRGB_lod6/weight', # [3, 512, 1, 1] + 'output2.wscale.bias': 'ToRGB_lod6/bias', # [3] + 'output3.conv.weight': 'ToRGB_lod5/weight', # [3, 512, 1, 1] + 'output3.wscale.bias': 'ToRGB_lod5/bias', # [3] + 'output4.conv.weight': 'ToRGB_lod4/weight', # [3, 256, 1, 1] + 'output4.wscale.bias': 'ToRGB_lod4/bias', # [3] + 'output5.conv.weight': 'ToRGB_lod3/weight', # [3, 128, 1, 1] + 'output5.wscale.bias': 'ToRGB_lod3/bias', # [3] + 'output6.conv.weight': 'ToRGB_lod2/weight', # [3, 64, 1, 1] + 'output6.wscale.bias': 'ToRGB_lod2/bias', # [3] + 'output7.conv.weight': 'ToRGB_lod1/weight', # [3, 32, 1, 1] + 'output7.wscale.bias': 'ToRGB_lod1/bias', # [3] + 'output8.conv.weight': 'ToRGB_lod0/weight', # [3, 16, 1, 1] + 'output8.wscale.bias': 'ToRGB_lod0/bias', # [3] } -class PGGANGeneratorModel(nn.Sequential): +class PGGANGeneratorModel(nn.Module): """Defines the generator module in ProgressiveGAN. Note that the generated images are with RGB color channels with range [-1, 1]. """ - def __init__(self, resolution=1024, final_tanh=False): + def __init__(self, + resolution=1024, + fused_scale=False, + output_channels=3): """Initializes the generator with basic settings. Args: resolution: The resolution of the final output image. (default: 1024) - final_tanh: Whether to use a `tanh` function to clamp the pixel values of - the output image to range [-1, 1]. (default: False) + fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting + in `conv2d_transpose`. (default: False) + output_channels: Number of channels of the output image. (default: 3) Raises: ValueError: If the input `resolution` is not supported. """ + super().__init__() + try: - channels = _RESOLUTIONS_TO_CHANNELS[resolution] + self.channels = _RESOLUTIONS_TO_CHANNELS[resolution] except KeyError: raise ValueError(f'Invalid resolution: {resolution}!\n' f'Resolutions allowed: ' f'{list(_RESOLUTIONS_TO_CHANNELS)}.') - - sequence = OrderedDict() - - def _add_layer(layer, name=None): - name = name or f'layer{len(sequence) + 1}' - sequence[name] = layer - - _add_layer(ConvBlock(channels[0], channels[1], kernel_size=4, padding=3)) - _add_layer(ConvBlock(channels[1], channels[1])) - for i in range(2, len(channels)): - _add_layer(ConvBlock(channels[i-1], channels[i], upsample=True)) - _add_layer(ConvBlock(channels[i], channels[i])) - # Final convolutional block.
- _add_layer(ConvBlock(in_channels=channels[-1], - out_channels=3, - kernel_size=1, - padding=0, - wscale_gain=1.0, - activation_type='tanh' if final_tanh else 'linear'), - name=f'output_{resolution}x{resolution}') - super().__init__(sequence) - self.pth_to_tf_var_mapping = _PGGAN_PTH_VARS_TO_TF_VARS + assert len(self.channels) == int(np.log2(resolution)) + + self.resolution = resolution + self.fused_scale = fused_scale + self.output_channels = output_channels + + for block_idx in range(1, len(self.channels)): + if block_idx == 1: + self.add_module( + f'layer{2 * block_idx - 2}', + ConvBlock(in_channels=self.channels[block_idx - 1], + out_channels=self.channels[block_idx], + kernel_size=4, + padding=3)) + else: + self.add_module( + f'layer{2 * block_idx - 2}', + ConvBlock(in_channels=self.channels[block_idx - 1], + out_channels=self.channels[block_idx], + upsample=True, + fused_scale=self.fused_scale)) + self.add_module( + f'layer{2 * block_idx - 1}', + ConvBlock(in_channels=self.channels[block_idx], + out_channels=self.channels[block_idx])) + self.add_module( + f'output{block_idx - 1}', + ConvBlock(in_channels=self.channels[block_idx], + out_channels=self.output_channels, + kernel_size=1, + padding=0, + wscale_gain=1.0, + activation_type='linear')) + + self.upsample = ResolutionScalingLayer() + self.lod = nn.Parameter(torch.zeros(())) + + self.pth_to_tf_var_mapping = {} + for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items(): + if self.fused_scale and 'Conv0' in tf_var_name: + pth_var_name = pth_var_name.replace('conv.weight', 'weight') + tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up') + self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name def forward(self, x): if len(x.shape) != 2: raise ValueError(f'The input tensor should be with shape [batch_size, ' f'noise_dim], but {x.shape} received!') x = x.view(x.shape[0], x.shape[1], 1, 1) - return super().forward(x) + + lod = self.lod.cpu().tolist() + for block_idx in range(1, len(self.channels)): + if block_idx + lod < len(self.channels): + x = self.__getattr__(f'layer{2 * block_idx - 2}')(x) + x = self.__getattr__(f'layer{2 * block_idx - 1}')(x) + image = self.__getattr__(f'output{block_idx - 1}')(x) + else: + image = self.upsample(image) + return image class PixelNormLayer(nn.Module): @@ -150,9 +203,7 @@ def __init__(self, scale_factor=2): self.scale_factor = scale_factor def forward(self, x): - return nn.functional.interpolate(x, - scale_factor=self.scale_factor, - mode='nearest') + return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest') class WScaleLayer(nn.Module): @@ -190,6 +241,7 @@ def __init__(self, dilation=1, add_bias=False, upsample=False, + fused_scale=False, wscale_gain=np.sqrt(2.0), activation_type='lrelu'): """Initializes the class with block settings. @@ -203,6 +255,8 @@ def __init__(self, dilation: Dilation rate for convolution operation. add_bias: Whether to add bias onto the convolutional result. upsample: Whether to upsample the input tensor before convolution. + fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting + in `conv2d_transpose`. wscale_gain: The gain factor for `wscale` layer. wscale_lr_multiplier: The learning rate multiplier factor for `wscale` layer.
@@ -214,19 +268,32 @@ def __init__(self, """ super().__init__() self.pixel_norm = PixelNormLayer() - self.upsample = ResolutionScalingLayer() if upsample else (lambda x: x) - self.conv = nn.Conv2d(in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - groups=1, - bias=add_bias) + + if upsample and not fused_scale: + self.upsample = ResolutionScalingLayer() + else: + self.upsample = lambda x: x + + if upsample and fused_scale: + self.weight = nn.Parameter( + torch.randn(kernel_size, kernel_size, in_channels, out_channels)) + fan_in = in_channels * kernel_size * kernel_size + self.scale = wscale_gain / np.sqrt(fan_in) + else: + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=1, + bias=add_bias) + self.wscale = WScaleLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, gain=wscale_gain) + if activation_type == 'linear': self.activate = (lambda x: x) elif activation_type == 'lrelu': @@ -240,7 +307,16 @@ def __init__(self, def forward(self, x): x = self.pixel_norm(x) x = self.upsample(x) - x = self.conv(x) + if hasattr(self, 'conv'): + x = self.conv(x) + else: + kernel = self.weight * self.scale + kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0) + kernel = (kernel[1:, 1:] + kernel[:-1, 1:] + + kernel[1:, :-1] + kernel[:-1, :-1]) + kernel = kernel.permute(2, 3, 0, 1) + x = F.conv_transpose2d(x, kernel, stride=2, padding=1) + x = x / self.scale x = self.wscale(x) x = self.activate(x) return x diff --git a/models/stylegan_generator.py b/models/stylegan_generator.py index 64ae423..c1683a4 100644 --- a/models/stylegan_generator.py +++ b/models/stylegan_generator.py @@ -45,9 +45,12 @@ def __init__(self, model_name, logger=None): def build(self): self.check_attr('w_space_dim') + self.check_attr('fused_scale') self.model = StyleGANGeneratorModel( resolution=self.resolution, w_space_dim=self.w_space_dim, + fused_scale=self.fused_scale, + output_channels=self.output_channels, truncation_psi=self.truncation_psi, truncation_layers=self.truncation_layers, randomize_noise=self.randomize_noise) @@ -59,6 +62,8 @@ def load(self): state_dict[var_name] = self.model.state_dict()[var_name] self.model.load_state_dict(state_dict) self.logger.info(f'Successfully loaded!') + self.lod = self.model.synthesis.lod.to(self.cpu_device).tolist() + self.logger.info(f' `lod` of the loaded model is {self.lod}.') def convert_tf_model(self, test_num=10): import sys @@ -82,9 +87,11 @@ def convert_tf_model(self, test_num=10): state_dict = self.model.state_dict() for pth_var_name, tf_var_name in self.model.pth_to_tf_var_mapping.items(): if tf_var_name not in tf_vars: + self.logger.debug(f'Variable `{tf_var_name}` does not exist in ' + f'tensorflow model.') continue self.logger.debug(f' Converting `{tf_var_name}` to `{pth_var_name}`.') - var = torch.from_numpy(tf_vars[tf_var_name]) + var = torch.from_numpy(np.array(tf_vars[tf_var_name])) if 'weight' in pth_var_name: if 'dense' in pth_var_name: var = var.permute(1, 0) diff --git a/models/stylegan_generator_model.py b/models/stylegan_generator_model.py index f4eaecf..f4771db 100644 --- a/models/stylegan_generator_model.py +++ b/models/stylegan_generator_model.py @@ -31,270 +31,165 @@ class is specially used for inference. 
1024: [512, 512, 512, 512, 512, 256, 128, 64, 32, 16], } +# pylint: disable=line-too-long # Variable mapping from pytorch model to official tensorflow model. _STYLEGAN_PTH_VARS_TO_TF_VARS = { # Statistic information of disentangled latent feature, w. - 'truncation.w_avg': - 'dlatent_avg', # [512] + 'truncation.w_avg': 'dlatent_avg', # [512] # Noises. - 'synthesis.layer0.epilogue.apply_noise.noise': - 'noise0', # [1, 1, 4, 4] - 'synthesis.layer1.epilogue.apply_noise.noise': - 'noise1', # [1, 1, 4, 4] - 'synthesis.layer2.epilogue.apply_noise.noise': - 'noise2', # [1, 1, 8, 8] - 'synthesis.layer3.epilogue.apply_noise.noise': - 'noise3', # [1, 1, 8, 8] - 'synthesis.layer4.epilogue.apply_noise.noise': - 'noise4', # [1, 1, 16, 16] - 'synthesis.layer5.epilogue.apply_noise.noise': - 'noise5', # [1, 1, 16, 16] - 'synthesis.layer6.epilogue.apply_noise.noise': - 'noise6', # [1, 1, 32, 32] - 'synthesis.layer7.epilogue.apply_noise.noise': - 'noise7', # [1, 1, 32, 32] - 'synthesis.layer8.epilogue.apply_noise.noise': - 'noise8', # [1, 1, 64, 64] - 'synthesis.layer9.epilogue.apply_noise.noise': - 'noise9', # [1, 1, 64, 64] - 'synthesis.layer10.epilogue.apply_noise.noise': - 'noise10', # [1, 1, 128, 128] - 'synthesis.layer11.epilogue.apply_noise.noise': - 'noise11', # [1, 1, 128, 128] - 'synthesis.layer12.epilogue.apply_noise.noise': - 'noise12', # [1, 1, 256, 256] - 'synthesis.layer13.epilogue.apply_noise.noise': - 'noise13', # [1, 1, 256, 256] - 'synthesis.layer14.epilogue.apply_noise.noise': - 'noise14', # [1, 1, 512, 512] - 'synthesis.layer15.epilogue.apply_noise.noise': - 'noise15', # [1, 1, 512, 512] - 'synthesis.layer16.epilogue.apply_noise.noise': - 'noise16', # [1, 1, 1024, 1024] - 'synthesis.layer17.epilogue.apply_noise.noise': - 'noise17', # [1, 1, 1024, 1024] + 'synthesis.layer0.epilogue.apply_noise.noise': 'noise0', # [1, 1, 4, 4] + 'synthesis.layer1.epilogue.apply_noise.noise': 'noise1', # [1, 1, 4, 4] + 'synthesis.layer2.epilogue.apply_noise.noise': 'noise2', # [1, 1, 8, 8] + 'synthesis.layer3.epilogue.apply_noise.noise': 'noise3', # [1, 1, 8, 8] + 'synthesis.layer4.epilogue.apply_noise.noise': 'noise4', # [1, 1, 16, 16] + 'synthesis.layer5.epilogue.apply_noise.noise': 'noise5', # [1, 1, 16, 16] + 'synthesis.layer6.epilogue.apply_noise.noise': 'noise6', # [1, 1, 32, 32] + 'synthesis.layer7.epilogue.apply_noise.noise': 'noise7', # [1, 1, 32, 32] + 'synthesis.layer8.epilogue.apply_noise.noise': 'noise8', # [1, 1, 64, 64] + 'synthesis.layer9.epilogue.apply_noise.noise': 'noise9', # [1, 1, 64, 64] + 'synthesis.layer10.epilogue.apply_noise.noise': 'noise10', # [1, 1, 128, 128] + 'synthesis.layer11.epilogue.apply_noise.noise': 'noise11', # [1, 1, 128, 128] + 'synthesis.layer12.epilogue.apply_noise.noise': 'noise12', # [1, 1, 256, 256] + 'synthesis.layer13.epilogue.apply_noise.noise': 'noise13', # [1, 1, 256, 256] + 'synthesis.layer14.epilogue.apply_noise.noise': 'noise14', # [1, 1, 512, 512] + 'synthesis.layer15.epilogue.apply_noise.noise': 'noise15', # [1, 1, 512, 512] + 'synthesis.layer16.epilogue.apply_noise.noise': 'noise16', # [1, 1, 1024, 1024] + 'synthesis.layer17.epilogue.apply_noise.noise': 'noise17', # [1, 1, 1024, 1024] # Mapping blocks.
- 'mapping.dense0.linear.weight': - 'Dense0/weight', # [512, 512] - 'mapping.dense0.wscale.bias': - 'Dense0/bias', # [512] - 'mapping.dense1.linear.weight': - 'Dense1/weight', # [512, 512] - 'mapping.dense1.wscale.bias': - 'Dense1/bias', # [512] - 'mapping.dense2.linear.weight': - 'Dense2/weight', # [512, 512] - 'mapping.dense2.wscale.bias': - 'Dense2/bias', # [512] - 'mapping.dense3.linear.weight': - 'Dense3/weight', # [512, 512] - 'mapping.dense3.wscale.bias': - 'Dense3/bias', # [512] - 'mapping.dense4.linear.weight': - 'Dense4/weight', # [512, 512] - 'mapping.dense4.wscale.bias': - 'Dense4/bias', # [512] - 'mapping.dense5.linear.weight': - 'Dense5/weight', # [512, 512] - 'mapping.dense5.wscale.bias': - 'Dense5/bias', # [512] - 'mapping.dense6.linear.weight': - 'Dense6/weight', # [512, 512] - 'mapping.dense6.wscale.bias': - 'Dense6/bias', # [512] - 'mapping.dense7.linear.weight': - 'Dense7/weight', # [512, 512] - 'mapping.dense7.wscale.bias': - 'Dense7/bias', # [512] + 'mapping.dense0.linear.weight': 'Dense0/weight', # [512, 512] + 'mapping.dense0.wscale.bias': 'Dense0/bias', # [512] + 'mapping.dense1.linear.weight': 'Dense1/weight', # [512, 512] + 'mapping.dense1.wscale.bias': 'Dense1/bias', # [512] + 'mapping.dense2.linear.weight': 'Dense2/weight', # [512, 512] + 'mapping.dense2.wscale.bias': 'Dense2/bias', # [512] + 'mapping.dense3.linear.weight': 'Dense3/weight', # [512, 512] + 'mapping.dense3.wscale.bias': 'Dense3/bias', # [512] + 'mapping.dense4.linear.weight': 'Dense4/weight', # [512, 512] + 'mapping.dense4.wscale.bias': 'Dense4/bias', # [512] + 'mapping.dense5.linear.weight': 'Dense5/weight', # [512, 512] + 'mapping.dense5.wscale.bias': 'Dense5/bias', # [512] + 'mapping.dense6.linear.weight': 'Dense6/weight', # [512, 512] + 'mapping.dense6.wscale.bias': 'Dense6/bias', # [512] + 'mapping.dense7.linear.weight': 'Dense7/weight', # [512, 512] + 'mapping.dense7.wscale.bias': 'Dense7/bias', # [512] # Synthesis blocks. 
- 'synthesis.layer0.first_layer': - '4x4/Const/const', # [1, 512, 4, 4] - 'synthesis.layer0.epilogue.apply_noise.weight': - '4x4/Const/Noise/weight', # [512] - 'synthesis.layer0.epilogue.bias': - '4x4/Const/bias', # [512] - 'synthesis.layer0.epilogue.style_mod.dense.linear.weight': - '4x4/Const/StyleMod/weight', # [1024, 512] - 'synthesis.layer0.epilogue.style_mod.dense.wscale.bias': - '4x4/Const/StyleMod/bias', # [1024] - 'synthesis.layer1.conv.weight': - '4x4/Conv/weight', # [512, 512, 3, 3] - 'synthesis.layer1.epilogue.apply_noise.weight': - '4x4/Conv/Noise/weight', # [512] - 'synthesis.layer1.epilogue.bias': - '4x4/Conv/bias', # [512] - 'synthesis.layer1.epilogue.style_mod.dense.linear.weight': - '4x4/Conv/StyleMod/weight', # [1024, 512] - 'synthesis.layer1.epilogue.style_mod.dense.wscale.bias': - '4x4/Conv/StyleMod/bias', # [1024] - 'synthesis.layer2.conv.weight': - '8x8/Conv0_up/weight', # [512, 512, 3, 3] - 'synthesis.layer2.epilogue.apply_noise.weight': - '8x8/Conv0_up/Noise/weight', # [512] - 'synthesis.layer2.epilogue.bias': - '8x8/Conv0_up/bias', # [512] - 'synthesis.layer2.epilogue.style_mod.dense.linear.weight': - '8x8/Conv0_up/StyleMod/weight', # [1024, 512] - 'synthesis.layer2.epilogue.style_mod.dense.wscale.bias': - '8x8/Conv0_up/StyleMod/bias', # [1024] - 'synthesis.layer3.conv.weight': - '8x8/Conv1/weight', # [512, 512, 3, 3] - 'synthesis.layer3.epilogue.apply_noise.weight': - '8x8/Conv1/Noise/weight', # [512] - 'synthesis.layer3.epilogue.bias': - '8x8/Conv1/bias', # [512] - 'synthesis.layer3.epilogue.style_mod.dense.linear.weight': - '8x8/Conv1/StyleMod/weight', # [1024, 512] - 'synthesis.layer3.epilogue.style_mod.dense.wscale.bias': - '8x8/Conv1/StyleMod/bias', # [1024] - 'synthesis.layer4.conv.weight': - '16x16/Conv0_up/weight', # [512, 512, 3, 3] - 'synthesis.layer4.epilogue.apply_noise.weight': - '16x16/Conv0_up/Noise/weight', # [512] - 'synthesis.layer4.epilogue.bias': - '16x16/Conv0_up/bias', # [512] - 'synthesis.layer4.epilogue.style_mod.dense.linear.weight': - '16x16/Conv0_up/StyleMod/weight', # [1024, 512] - 'synthesis.layer4.epilogue.style_mod.dense.wscale.bias': - '16x16/Conv0_up/StyleMod/bias', # [1024] - 'synthesis.layer5.conv.weight': - '16x16/Conv1/weight', # [512, 512, 3, 3] - 'synthesis.layer5.epilogue.apply_noise.weight': - '16x16/Conv1/Noise/weight', # [512] - 'synthesis.layer5.epilogue.bias': - '16x16/Conv1/bias', # [512] - 'synthesis.layer5.epilogue.style_mod.dense.linear.weight': - '16x16/Conv1/StyleMod/weight', # [1024, 512] - 'synthesis.layer5.epilogue.style_mod.dense.wscale.bias': - '16x16/Conv1/StyleMod/bias', # [1024] - 'synthesis.layer6.conv.weight': - '32x32/Conv0_up/weight', # [512, 512, 3, 3] - 'synthesis.layer6.epilogue.apply_noise.weight': - '32x32/Conv0_up/Noise/weight', # [512] - 'synthesis.layer6.epilogue.bias': - '32x32/Conv0_up/bias', # [512] - 'synthesis.layer6.epilogue.style_mod.dense.linear.weight': - '32x32/Conv0_up/StyleMod/weight', # [1024, 512] - 'synthesis.layer6.epilogue.style_mod.dense.wscale.bias': - '32x32/Conv0_up/StyleMod/bias', # [1024] - 'synthesis.layer7.conv.weight': - '32x32/Conv1/weight', # [512, 512, 3, 3] - 'synthesis.layer7.epilogue.apply_noise.weight': - '32x32/Conv1/Noise/weight', # [512] - 'synthesis.layer7.epilogue.bias': - '32x32/Conv1/bias', # [512] - 'synthesis.layer7.epilogue.style_mod.dense.linear.weight': - '32x32/Conv1/StyleMod/weight', # [1024, 512] - 'synthesis.layer7.epilogue.style_mod.dense.wscale.bias': - '32x32/Conv1/StyleMod/bias', # [1024] - 'synthesis.layer8.conv.weight': - 
'64x64/Conv0_up/weight', # [256, 512, 3, 3] - 'synthesis.layer8.epilogue.apply_noise.weight': - '64x64/Conv0_up/Noise/weight', # [256] - 'synthesis.layer8.epilogue.bias': - '64x64/Conv0_up/bias', # [256] - 'synthesis.layer8.epilogue.style_mod.dense.linear.weight': - '64x64/Conv0_up/StyleMod/weight', # [512, 512] - 'synthesis.layer8.epilogue.style_mod.dense.wscale.bias': - '64x64/Conv0_up/StyleMod/bias', # [512] - 'synthesis.layer9.conv.weight': - '64x64/Conv1/weight', # [256, 256, 3, 3] - 'synthesis.layer9.epilogue.apply_noise.weight': - '64x64/Conv1/Noise/weight', # [256] - 'synthesis.layer9.epilogue.bias': - '64x64/Conv1/bias', # [256] - 'synthesis.layer9.epilogue.style_mod.dense.linear.weight': - '64x64/Conv1/StyleMod/weight', # [512, 512] - 'synthesis.layer9.epilogue.style_mod.dense.wscale.bias': - '64x64/Conv1/StyleMod/bias', # [512] - 'synthesis.layer10.weight': - '128x128/Conv0_up/weight', # [3, 3, 256, 128] - 'synthesis.layer10.epilogue.apply_noise.weight': - '128x128/Conv0_up/Noise/weight', # [128] - 'synthesis.layer10.epilogue.bias': - '128x128/Conv0_up/bias', # [128] - 'synthesis.layer10.epilogue.style_mod.dense.linear.weight': - '128x128/Conv0_up/StyleMod/weight', # [256, 512] - 'synthesis.layer10.epilogue.style_mod.dense.wscale.bias': - '128x128/Conv0_up/StyleMod/bias', # [256] - 'synthesis.layer11.conv.weight': - '128x128/Conv1/weight', # [128, 128, 3, 3] - 'synthesis.layer11.epilogue.apply_noise.weight': - '128x128/Conv1/Noise/weight', # [128] - 'synthesis.layer11.epilogue.bias': - '128x128/Conv1/bias', # [128] - 'synthesis.layer11.epilogue.style_mod.dense.linear.weight': - '128x128/Conv1/StyleMod/weight', # [256, 512] - 'synthesis.layer11.epilogue.style_mod.dense.wscale.bias': - '128x128/Conv1/StyleMod/bias', # [256] - 'synthesis.layer12.weight': - '256x256/Conv0_up/weight', # [3, 3, 128, 64] - 'synthesis.layer12.epilogue.apply_noise.weight': - '256x256/Conv0_up/Noise/weight', # [64] - 'synthesis.layer12.epilogue.bias': - '256x256/Conv0_up/bias', # [64] - 'synthesis.layer12.epilogue.style_mod.dense.linear.weight': - '256x256/Conv0_up/StyleMod/weight', # [128, 512] - 'synthesis.layer12.epilogue.style_mod.dense.wscale.bias': - '256x256/Conv0_up/StyleMod/bias', # [128] - 'synthesis.layer13.conv.weight': - '256x256/Conv1/weight', # [64, 64, 3, 3] - 'synthesis.layer13.epilogue.apply_noise.weight': - '256x256/Conv1/Noise/weight', # [64] - 'synthesis.layer13.epilogue.bias': - '256x256/Conv1/bias', # [64] - 'synthesis.layer13.epilogue.style_mod.dense.linear.weight': - '256x256/Conv1/StyleMod/weight', # [128, 512] - 'synthesis.layer13.epilogue.style_mod.dense.wscale.bias': - '256x256/Conv1/StyleMod/bias', # [128] - 'synthesis.layer14.weight': - '512x512/Conv0_up/weight', # [3, 3, 64, 32] - 'synthesis.layer14.epilogue.apply_noise.weight': - '512x512/Conv0_up/Noise/weight', # [32] - 'synthesis.layer14.epilogue.bias': - '512x512/Conv0_up/bias', # [32] - 'synthesis.layer14.epilogue.style_mod.dense.linear.weight': - '512x512/Conv0_up/StyleMod/weight', # [64, 512] - 'synthesis.layer14.epilogue.style_mod.dense.wscale.bias': - '512x512/Conv0_up/StyleMod/bias', # [64] - 'synthesis.layer15.conv.weight': - '512x512/Conv1/weight', # [32, 32, 3, 3] - 'synthesis.layer15.epilogue.apply_noise.weight': - '512x512/Conv1/Noise/weight', # [32] - 'synthesis.layer15.epilogue.bias': - '512x512/Conv1/bias', # [32] - 'synthesis.layer15.epilogue.style_mod.dense.linear.weight': - '512x512/Conv1/StyleMod/weight', # [64, 512] - 'synthesis.layer15.epilogue.style_mod.dense.wscale.bias': - 
'512x512/Conv1/StyleMod/bias', # [64] - 'synthesis.layer16.weight': - '1024x1024/Conv0_up/weight', # [3, 3, 32, 16] - 'synthesis.layer16.epilogue.apply_noise.weight': - '1024x1024/Conv0_up/Noise/weight', # [16] - 'synthesis.layer16.epilogue.bias': - '1024x1024/Conv0_up/bias', # [16] - 'synthesis.layer16.epilogue.style_mod.dense.linear.weight': - '1024x1024/Conv0_up/StyleMod/weight', # [32, 512] - 'synthesis.layer16.epilogue.style_mod.dense.wscale.bias': - '1024x1024/Conv0_up/StyleMod/bias', # [32] - 'synthesis.layer17.conv.weight': - '1024x1024/Conv1/weight', # [16, 16, 3, 3] - 'synthesis.layer17.epilogue.apply_noise.weight': - '1024x1024/Conv1/Noise/weight', # [16] - 'synthesis.layer17.epilogue.bias': - '1024x1024/Conv1/bias', # [16] - 'synthesis.layer17.epilogue.style_mod.dense.linear.weight': - '1024x1024/Conv1/StyleMod/weight', # [32, 512] - 'synthesis.layer17.epilogue.style_mod.dense.wscale.bias': - '1024x1024/Conv1/StyleMod/bias', # [32] - 'synthesis.output.conv.weight': - 'ToRGB_lod0/weight', # [3, 16, 1, 1] - 'synthesis.output.bias': - 'ToRGB_lod0/bias', # [3] + 'synthesis.lod': 'lod', # [] + 'synthesis.layer0.first_layer': '4x4/Const/const', # [1, 512, 4, 4] + 'synthesis.layer0.epilogue.apply_noise.weight': '4x4/Const/Noise/weight', # [512] + 'synthesis.layer0.epilogue.bias': '4x4/Const/bias', # [512] + 'synthesis.layer0.epilogue.style_mod.dense.linear.weight': '4x4/Const/StyleMod/weight', # [1024, 512] + 'synthesis.layer0.epilogue.style_mod.dense.wscale.bias': '4x4/Const/StyleMod/bias', # [1024] + 'synthesis.layer1.conv.weight': '4x4/Conv/weight', # [512, 512, 3, 3] + 'synthesis.layer1.epilogue.apply_noise.weight': '4x4/Conv/Noise/weight', # [512] + 'synthesis.layer1.epilogue.bias': '4x4/Conv/bias', # [512] + 'synthesis.layer1.epilogue.style_mod.dense.linear.weight': '4x4/Conv/StyleMod/weight', # [1024, 512] + 'synthesis.layer1.epilogue.style_mod.dense.wscale.bias': '4x4/Conv/StyleMod/bias', # [1024] + 'synthesis.layer2.conv.weight': '8x8/Conv0_up/weight', # [512, 512, 3, 3] + 'synthesis.layer2.epilogue.apply_noise.weight': '8x8/Conv0_up/Noise/weight', # [512] + 'synthesis.layer2.epilogue.bias': '8x8/Conv0_up/bias', # [512] + 'synthesis.layer2.epilogue.style_mod.dense.linear.weight': '8x8/Conv0_up/StyleMod/weight', # [1024, 512] + 'synthesis.layer2.epilogue.style_mod.dense.wscale.bias': '8x8/Conv0_up/StyleMod/bias', # [1024] + 'synthesis.layer3.conv.weight': '8x8/Conv1/weight', # [512, 512, 3, 3] + 'synthesis.layer3.epilogue.apply_noise.weight': '8x8/Conv1/Noise/weight', # [512] + 'synthesis.layer3.epilogue.bias': '8x8/Conv1/bias', # [512] + 'synthesis.layer3.epilogue.style_mod.dense.linear.weight': '8x8/Conv1/StyleMod/weight', # [1024, 512] + 'synthesis.layer3.epilogue.style_mod.dense.wscale.bias': '8x8/Conv1/StyleMod/bias', # [1024] + 'synthesis.layer4.conv.weight': '16x16/Conv0_up/weight', # [512, 512, 3, 3] + 'synthesis.layer4.epilogue.apply_noise.weight': '16x16/Conv0_up/Noise/weight', # [512] + 'synthesis.layer4.epilogue.bias': '16x16/Conv0_up/bias', # [512] + 'synthesis.layer4.epilogue.style_mod.dense.linear.weight': '16x16/Conv0_up/StyleMod/weight', # [1024, 512] + 'synthesis.layer4.epilogue.style_mod.dense.wscale.bias': '16x16/Conv0_up/StyleMod/bias', # [1024] + 'synthesis.layer5.conv.weight': '16x16/Conv1/weight', # [512, 512, 3, 3] + 'synthesis.layer5.epilogue.apply_noise.weight': '16x16/Conv1/Noise/weight', # [512] + 'synthesis.layer5.epilogue.bias': '16x16/Conv1/bias', # [512] + 'synthesis.layer5.epilogue.style_mod.dense.linear.weight': '16x16/Conv1/StyleMod/weight', 
# [1024, 512] + 'synthesis.layer5.epilogue.style_mod.dense.wscale.bias': '16x16/Conv1/StyleMod/bias', # [1024] + 'synthesis.layer6.conv.weight': '32x32/Conv0_up/weight', # [512, 512, 3, 3] + 'synthesis.layer6.epilogue.apply_noise.weight': '32x32/Conv0_up/Noise/weight', # [512] + 'synthesis.layer6.epilogue.bias': '32x32/Conv0_up/bias', # [512] + 'synthesis.layer6.epilogue.style_mod.dense.linear.weight': '32x32/Conv0_up/StyleMod/weight', # [1024, 512] + 'synthesis.layer6.epilogue.style_mod.dense.wscale.bias': '32x32/Conv0_up/StyleMod/bias', # [1024] + 'synthesis.layer7.conv.weight': '32x32/Conv1/weight', # [512, 512, 3, 3] + 'synthesis.layer7.epilogue.apply_noise.weight': '32x32/Conv1/Noise/weight', # [512] + 'synthesis.layer7.epilogue.bias': '32x32/Conv1/bias', # [512] + 'synthesis.layer7.epilogue.style_mod.dense.linear.weight': '32x32/Conv1/StyleMod/weight', # [1024, 512] + 'synthesis.layer7.epilogue.style_mod.dense.wscale.bias': '32x32/Conv1/StyleMod/bias', # [1024] + 'synthesis.layer8.conv.weight': '64x64/Conv0_up/weight', # [256, 512, 3, 3] + 'synthesis.layer8.epilogue.apply_noise.weight': '64x64/Conv0_up/Noise/weight', # [256] + 'synthesis.layer8.epilogue.bias': '64x64/Conv0_up/bias', # [256] + 'synthesis.layer8.epilogue.style_mod.dense.linear.weight': '64x64/Conv0_up/StyleMod/weight', # [512, 512] + 'synthesis.layer8.epilogue.style_mod.dense.wscale.bias': '64x64/Conv0_up/StyleMod/bias', # [512] + 'synthesis.layer9.conv.weight': '64x64/Conv1/weight', # [256, 256, 3, 3] + 'synthesis.layer9.epilogue.apply_noise.weight': '64x64/Conv1/Noise/weight', # [256] + 'synthesis.layer9.epilogue.bias': '64x64/Conv1/bias', # [256] + 'synthesis.layer9.epilogue.style_mod.dense.linear.weight': '64x64/Conv1/StyleMod/weight', # [512, 512] + 'synthesis.layer9.epilogue.style_mod.dense.wscale.bias': '64x64/Conv1/StyleMod/bias', # [512] + 'synthesis.layer10.conv.weight': '128x128/Conv0_up/weight', # [128, 256, 3, 3] + 'synthesis.layer10.epilogue.apply_noise.weight': '128x128/Conv0_up/Noise/weight', # [128] + 'synthesis.layer10.epilogue.bias': '128x128/Conv0_up/bias', # [128] + 'synthesis.layer10.epilogue.style_mod.dense.linear.weight': '128x128/Conv0_up/StyleMod/weight', # [256, 512] + 'synthesis.layer10.epilogue.style_mod.dense.wscale.bias': '128x128/Conv0_up/StyleMod/bias', # [256] + 'synthesis.layer11.conv.weight': '128x128/Conv1/weight', # [128, 128, 3, 3] + 'synthesis.layer11.epilogue.apply_noise.weight': '128x128/Conv1/Noise/weight', # [128] + 'synthesis.layer11.epilogue.bias': '128x128/Conv1/bias', # [128] + 'synthesis.layer11.epilogue.style_mod.dense.linear.weight': '128x128/Conv1/StyleMod/weight', # [256, 512] + 'synthesis.layer11.epilogue.style_mod.dense.wscale.bias': '128x128/Conv1/StyleMod/bias', # [256] + 'synthesis.layer12.conv.weight': '256x256/Conv0_up/weight', # [64, 128, 3, 3] + 'synthesis.layer12.epilogue.apply_noise.weight': '256x256/Conv0_up/Noise/weight', # [64] + 'synthesis.layer12.epilogue.bias': '256x256/Conv0_up/bias', # [64] + 'synthesis.layer12.epilogue.style_mod.dense.linear.weight': '256x256/Conv0_up/StyleMod/weight', # [128, 512] + 'synthesis.layer12.epilogue.style_mod.dense.wscale.bias': '256x256/Conv0_up/StyleMod/bias', # [128] + 'synthesis.layer13.conv.weight': '256x256/Conv1/weight', # [64, 64, 3, 3] + 'synthesis.layer13.epilogue.apply_noise.weight': '256x256/Conv1/Noise/weight', # [64] + 'synthesis.layer13.epilogue.bias': '256x256/Conv1/bias', # [64] + 'synthesis.layer13.epilogue.style_mod.dense.linear.weight': '256x256/Conv1/StyleMod/weight', # [128, 512] + 
'synthesis.layer13.epilogue.style_mod.dense.wscale.bias': '256x256/Conv1/StyleMod/bias', # [128] + 'synthesis.layer14.conv.weight': '512x512/Conv0_up/weight', # [32, 64, 3, 3] + 'synthesis.layer14.epilogue.apply_noise.weight': '512x512/Conv0_up/Noise/weight', # [32] + 'synthesis.layer14.epilogue.bias': '512x512/Conv0_up/bias', # [32] + 'synthesis.layer14.epilogue.style_mod.dense.linear.weight': '512x512/Conv0_up/StyleMod/weight', # [64, 512] + 'synthesis.layer14.epilogue.style_mod.dense.wscale.bias': '512x512/Conv0_up/StyleMod/bias', # [64] + 'synthesis.layer15.conv.weight': '512x512/Conv1/weight', # [32, 32, 3, 3] + 'synthesis.layer15.epilogue.apply_noise.weight': '512x512/Conv1/Noise/weight', # [32] + 'synthesis.layer15.epilogue.bias': '512x512/Conv1/bias', # [32] + 'synthesis.layer15.epilogue.style_mod.dense.linear.weight': '512x512/Conv1/StyleMod/weight', # [64, 512] + 'synthesis.layer15.epilogue.style_mod.dense.wscale.bias': '512x512/Conv1/StyleMod/bias', # [64] + 'synthesis.layer16.conv.weight': '1024x1024/Conv0_up/weight', # [16, 32, 3, 3] + 'synthesis.layer16.epilogue.apply_noise.weight': '1024x1024/Conv0_up/Noise/weight', # [16] + 'synthesis.layer16.epilogue.bias': '1024x1024/Conv0_up/bias', # [16] + 'synthesis.layer16.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv0_up/StyleMod/weight', # [32, 512] + 'synthesis.layer16.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv0_up/StyleMod/bias', # [32] + 'synthesis.layer17.conv.weight': '1024x1024/Conv1/weight', # [16, 16, 3, 3] + 'synthesis.layer17.epilogue.apply_noise.weight': '1024x1024/Conv1/Noise/weight', # [16] + 'synthesis.layer17.epilogue.bias': '1024x1024/Conv1/bias', # [16] + 'synthesis.layer17.epilogue.style_mod.dense.linear.weight': '1024x1024/Conv1/StyleMod/weight', # [32, 512] + 'synthesis.layer17.epilogue.style_mod.dense.wscale.bias': '1024x1024/Conv1/StyleMod/bias', # [32] + 'synthesis.output0.conv.weight': 'ToRGB_lod8/weight', # [3, 512, 1, 1] + 'synthesis.output0.bias': 'ToRGB_lod8/bias', # [3] + 'synthesis.output1.conv.weight': 'ToRGB_lod7/weight', # [3, 512, 1, 1] + 'synthesis.output1.bias': 'ToRGB_lod7/bias', # [3] + 'synthesis.output2.conv.weight': 'ToRGB_lod6/weight', # [3, 512, 1, 1] + 'synthesis.output2.bias': 'ToRGB_lod6/bias', # [3] + 'synthesis.output3.conv.weight': 'ToRGB_lod5/weight', # [3, 512, 1, 1] + 'synthesis.output3.bias': 'ToRGB_lod5/bias', # [3] + 'synthesis.output4.conv.weight': 'ToRGB_lod4/weight', # [3, 256, 1, 1] + 'synthesis.output4.bias': 'ToRGB_lod4/bias', # [3] + 'synthesis.output5.conv.weight': 'ToRGB_lod3/weight', # [3, 128, 1, 1] + 'synthesis.output5.bias': 'ToRGB_lod3/bias', # [3] + 'synthesis.output6.conv.weight': 'ToRGB_lod2/weight', # [3, 64, 1, 1] + 'synthesis.output6.bias': 'ToRGB_lod2/bias', # [3] + 'synthesis.output7.conv.weight': 'ToRGB_lod1/weight', # [3, 32, 1, 1] + 'synthesis.output7.bias': 'ToRGB_lod1/bias', # [3] + 'synthesis.output8.conv.weight': 'ToRGB_lod0/weight', # [3, 16, 1, 1] + 'synthesis.output8.bias': 'ToRGB_lod0/bias', # [3] } +# pylint: enable=line-too-long + +# Minimal resolution for `auto` fused-scale strategy. +_AUTO_FUSED_SCALE_MIN_RES = 128 class StyleGANGeneratorModel(nn.Module): @@ -306,31 +201,59 @@ class StyleGANGeneratorModel(nn.Module): def __init__(self, resolution=1024, w_space_dim=512, + fused_scale='auto', + output_channels=3, truncation_psi=0.7, truncation_layers=8, randomize_noise=False): """Initializes the generator with basic settings. Args: - resolution: The resolution of the final output image. 
+ resolution: The resolution of the final output image. (default: 1024) w_space_dim: The dimension of the disentangled latent vectors, w. + (default: 512) + fused_scale: If set as `True`, `conv2d_transpose` is used for upscaling. + If set as `False`, `upsample + conv2d` is used for upscaling. If set as + `auto`, `upsample + conv2d` is used for bottom layers until resolution + reaches 128. (default: `auto`) + output_channels: Number of channels of output image. (default: 3) truncation_psi: Style strength multiplier for the truncation trick. - `None` or `1.0` indicates no truncation. + `None` or `1.0` indicates no truncation. (default: 0.7) truncation_layers: Number of layers for which to apply the truncation - trick. `None` indicates no truncation. + trick. `None` indicates no truncation. (default: 8) + randomize_noise: Whether to add random noise for each convolutional layer. + (default: False) Raises: ValueError: If the input `resolution` is not supported. """ super().__init__() - self.mapping = MappingModule(final_space_dim=w_space_dim) - self.truncation = TruncationModule(resolution=resolution, - w_space_dim=w_space_dim, - truncation_psi=truncation_psi, - truncation_layers=truncation_layers) - self.synthesis = SynthesisModule(resolution=resolution, - randomize_noise=randomize_noise) - self.pth_to_tf_var_mapping = _STYLEGAN_PTH_VARS_TO_TF_VARS + self.resolution = resolution + self.w_space_dim = w_space_dim + self.fused_scale = fused_scale + self.output_channels = output_channels + self.truncation_psi = truncation_psi + self.truncation_layers = truncation_layers + self.randomize_noise = randomize_noise + + self.mapping = MappingModule(final_space_dim=self.w_space_dim) + self.truncation = TruncationModule(resolution=self.resolution, + w_space_dim=self.w_space_dim, + truncation_psi=self.truncation_psi, + truncation_layers=self.truncation_layers) + self.synthesis = SynthesisModule(resolution=self.resolution, + fused_scale=self.fused_scale, + output_channels=self.output_channels, + randomize_noise=self.randomize_noise) + + self.pth_to_tf_var_mapping = {} + for pth_var_name, tf_var_name in _STYLEGAN_PTH_VARS_TO_TF_VARS.items(): + if 'Conv0_up' in tf_var_name: + res = int(tf_var_name.split('x')[0]) + if ((self.fused_scale is True) or + (self.fused_scale == 'auto' and res >= _AUTO_FUSED_SCALE_MIN_RES)): + pth_var_name = pth_var_name.replace('conv.weight', 'weight') + self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name def forward(self, z): w = self.mapping(z) @@ -413,41 +336,63 @@ class SynthesisModule(nn.Module): def __init__(self, resolution=1024, + fused_scale='auto', + output_channels=3, randomize_noise=False): super().__init__() try: - channels = _RESOLUTIONS_TO_CHANNELS[resolution] + self.channels = _RESOLUTIONS_TO_CHANNELS[resolution] except KeyError: raise ValueError(f'Invalid resolution: {resolution}!\n' f'Resolutions allowed: ' f'{list(_RESOLUTIONS_TO_CHANNELS)}.') + assert len(self.channels) == int(np.log2(resolution)) - self.num_layers = int(np.log2(resolution)) * 2 - 2 - for i in range(1, len(channels)): - if i == 1: - self.add_module('layer0', FirstConvBlock(channels[0], randomize_noise)) + for block_idx in range(1, len(self.channels)): + if block_idx == 1: + self.add_module( + f'layer{2 * block_idx - 2}', + FirstConvBlock(in_channels=self.channels[block_idx - 1], + randomize_noise=randomize_noise)) else: self.add_module( - f'layer{i * 2 - 2}', - UpConvBlock(layer_idx=i * 2 - 2, - in_channels=channels[i - 1], - out_channels=channels[i], - randomize_noise=randomize_noise)) + 
f'layer{2 * block_idx - 2}', + UpConvBlock(layer_idx=2 * block_idx - 2, + in_channels=self.channels[block_idx - 1], + out_channels=self.channels[block_idx], + randomize_noise=randomize_noise, + fused_scale=fused_scale)) self.add_module( - f'layer{i * 2 - 1}', - ConvBlock(layer_idx=i * 2 - 1, - in_channels=channels[i], - out_channels=channels[i], + f'layer{2 * block_idx - 1}', + ConvBlock(layer_idx=2 * block_idx - 1, + in_channels=self.channels[block_idx], + out_channels=self.channels[block_idx], randomize_noise=randomize_noise)) - self.add_module('output', LastConvBlock(channels[-1])) + self.add_module( + f'output{block_idx - 1}', + LastConvBlock(in_channels=self.channels[block_idx], + out_channels=output_channels)) + + self.upsample = ResolutionScalingLayer() + self.lod = nn.Parameter(torch.zeros(())) def forward(self, w): + lod = self.lod.cpu().tolist() x = self.layer0(w[:, 0]) - for i in range(1, self.num_layers): - x = self.__getattr__(f'layer{i}')(x, w[:, i]) - x = self.output(x) - return x + for block_idx in range(1, len(self.channels)): + if block_idx + lod < len(self.channels): + layer_idx = 2 * block_idx - 2 + if layer_idx == 0: + x = self.__getattr__(f'layer{layer_idx}')(w[:, layer_idx]) + else: + x = self.__getattr__(f'layer{layer_idx}')(x, w[:, layer_idx]) + layer_idx = 2 * block_idx - 1 + x = self.__getattr__(f'layer{layer_idx}')(x, w[:, layer_idx]) + image = self.__getattr__(f'output{block_idx - 1}')(x) + else: + image = self.upsample(image) + return image class PixelNormLayer(nn.Module): @@ -628,11 +573,11 @@ class FirstConvBlock(nn.Module): Basically, this block starts from a const input, which is `ones(512, 4, 4)`. """ - def __init__(self, channels, randomize_noise=False): + def __init__(self, in_channels, randomize_noise=False): super().__init__() - self.first_layer = nn.Parameter(torch.ones(1, channels, 4, 4)) + self.first_layer = nn.Parameter(torch.ones(1, in_channels, 4, 4)) self.epilogue = EpilogueBlock(layer_idx=0, - channels=channels, + channels=in_channels, randomize_noise=randomize_noise) def forward(self, w): @@ -657,6 +602,7 @@ def __init__(self, padding=1, dilation=1, add_bias=False, + fused_scale='auto', wscale_gain=np.sqrt(2.0), wscale_lr_multiplier=1.0, randomize_noise=False): @@ -670,22 +616,33 @@ def __init__(self, padding: Padding parameter for convolution operation. dilation: Dilation rate for convolution operation. add_bias: Whether to add bias onto the convolutional result. + fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting + in `conv2d_transpose`. wscale_gain: The gain factor for `wscale` layer. wscale_lr_multiplier: The learning rate multiplier factor for `wscale` layer. + randomize_noise: Whether to add random noise. Raises: ValueError: If the block is not applied to the first block for a - particular resolution. + particular resolution, or if `fused_scale` is not one of [True, False, + `auto`].
""" super().__init__() if layer_idx % 2 == 1: raise ValueError(f'This block is implemented as the first block of each ' f'resolution, but is applied to layer {layer_idx}!') + if fused_scale not in [True, False, 'auto']: + raise ValueError(f'`fused_scale` can only be [True, False, `auto`], ' + f'but {fused_scale} received!') - self.layer_idx = layer_idx + cur_res = 2 ** (layer_idx // 2 + 2) + if fused_scale == 'auto': + self.fused_scale = (cur_res >= _AUTO_FUSED_SCALE_MIN_RES) + else: + self.fused_scale = fused_scale - if self.layer_idx > 9: + if self.fused_scale: self.weight = nn.Parameter( torch.randn(kernel_size, kernel_size, in_channels, out_channels)) @@ -708,7 +665,7 @@ def __init__(self, randomize_noise=randomize_noise) def forward(self, x, w): - if self.layer_idx > 9: + if self.fused_scale: kernel = self.weight * self.scale kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0) kernel = (kernel[1:, 1:] + kernel[:-1, 1:] + @@ -755,6 +712,7 @@ def __init__(self, wscale_gain: The gain factor for `wscale` layer. wscale_lr_multiplier: The learning rate multiplier factor for `wscale` layer. + randomize_noise: Whether to add random noise. Raises: ValueError: If the block is not applied to the second block for a @@ -791,13 +749,13 @@ class LastConvBlock(nn.Module): Basically, this block converts the final feature map to RGB image. """ - def __init__(self, channels): + def __init__(self, in_channels, out_channels=3): super().__init__() - self.conv = nn.Conv2d(in_channels=channels, - out_channels=3, + self.conv = nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, kernel_size=1, bias=False) - self.scale = 1 / np.sqrt(channels) + self.scale = 1 / np.sqrt(in_channels) self.bias = nn.Parameter(torch.zeros(3)) def forward(self, x):