Support fused-scale strategy as well as loading from unfinished PGGAN and StyleGAN models.
ShenYujun committed Aug 14, 2019
1 parent 52261eb commit d214bc8
Showing 4 changed files with 428 additions and 380 deletions.
models/pggan_generator.py (15 changes: 11 additions & 4 deletions)
@@ -25,14 +25,17 @@ def __init__(self, model_name, logger=None):
     assert self.gan_type == 'pggan'
 
   def build(self):
-    self.check_attr('final_tanh')
+    self.check_attr('fused_scale')
     self.model = PGGANGeneratorModel(resolution=self.resolution,
-                                     final_tanh=self.final_tanh)
+                                     fused_scale=self.fused_scale,
+                                     output_channels=self.output_channels)
 
   def load(self):
     self.logger.info(f'Loading pytorch model from `{self.model_path}`.')
     self.model.load_state_dict(torch.load(self.model_path))
     self.logger.info(f'Successfully loaded!')
+    self.lod = self.model.lod.to(self.cpu_device).tolist()
+    self.logger.info(f' `lod` of the loaded model is {self.lod}.')
 
   def convert_tf_model(self, test_num=10):
     import sys
@@ -52,12 +55,16 @@ def convert_tf_model(self, test_num=10):
     state_dict = self.model.state_dict()
     for pth_var_name, tf_var_name in self.model.pth_to_tf_var_mapping.items():
       if tf_var_name not in tf_vars:
+        self.logger.debug(f'Variable `{tf_var_name}` does not exist in '
+                          f'tensorflow model.')
         continue
       self.logger.debug(f' Converting `{tf_var_name}` to `{pth_var_name}`.')
-      var = torch.from_numpy(tf_vars[tf_var_name])
+      var = torch.from_numpy(np.array(tf_vars[tf_var_name]))
       if 'weight' in pth_var_name:
-        if 'layer1.conv' in pth_var_name:
+        if 'layer0.conv' in pth_var_name:
           var = var.view(var.shape[0], -1, 4, 4).permute(1, 0, 2, 3).flip(2, 3)
+        elif 'Conv0_up' in tf_var_name:
+          var = var.permute(0, 1, 3, 2)
         else:
           var = var.permute(3, 2, 0, 1)
       state_dict[pth_var_name] = var
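A note on the permutes above (an illustration, not part of the commit): TensorFlow stores plain conv kernels as [kH, kW, C_in, C_out] while PyTorch's `nn.Conv2d` expects [C_out, C_in, kH, kW], hence `permute(3, 2, 0, 1)`; for the new `Conv0_up` variables, TensorFlow's `conv2d_transpose` kernels are [kH, kW, C_out, C_in] while the fused `ConvBlock` below stores [kH, kW, C_in, C_out], hence `permute(0, 1, 3, 2)`. A minimal sketch:

    import numpy as np
    import torch

    # Plain conv kernel: TF [kH, kW, C_in, C_out] -> PyTorch [C_out, C_in, kH, kW].
    tf_conv = np.random.randn(3, 3, 256, 128).astype(np.float32)
    pth_conv = torch.from_numpy(tf_conv).permute(3, 2, 0, 1)
    print(tuple(pth_conv.shape))  # (128, 256, 3, 3)

    # Fused-scale kernel: TF conv2d_transpose stores [kH, kW, C_out, C_in];
    # the fused ConvBlock keeps [kH, kW, C_in, C_out], so only the last two
    # axes are swapped.
    tf_up = np.random.randn(3, 3, 128, 256).astype(np.float32)
    pth_up = torch.from_numpy(tf_up).permute(0, 1, 3, 2)
    print(tuple(pth_up.shape))  # (3, 3, 256, 128)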
models/pggan_generator_model.py (236 changes: 156 additions & 80 deletions)
@@ -9,11 +9,11 @@
 https://arxiv.org/pdf/1710.10196.pdf
 """
 
-from collections import OrderedDict
 import numpy as np
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 __all__ = ['PGGANGeneratorModel']
 
@@ -32,99 +32,152 @@
 
 # Variable mapping from pytorch model to official tensorflow model.
 _PGGAN_PTH_VARS_TO_TF_VARS = {
-    'layer1.conv.weight': '4x4/Dense/weight',  # [512, 512, 4, 4]
-    'layer1.wscale.bias': '4x4/Dense/bias',  # [512]
-    'layer2.conv.weight': '4x4/Conv/weight',  # [512, 512, 3, 3]
-    'layer2.wscale.bias': '4x4/Conv/bias',  # [512]
-    'layer3.conv.weight': '8x8/Conv0/weight',  # [512, 512, 3, 3]
-    'layer3.wscale.bias': '8x8/Conv0/bias',  # [512]
-    'layer4.conv.weight': '8x8/Conv1/weight',  # [512, 512, 3, 3]
-    'layer4.wscale.bias': '8x8/Conv1/bias',  # [512]
-    'layer5.conv.weight': '16x16/Conv0/weight',  # [512, 512, 3, 3]
-    'layer5.wscale.bias': '16x16/Conv0/bias',  # [512]
-    'layer6.conv.weight': '16x16/Conv1/weight',  # [512, 512, 3, 3]
-    'layer6.wscale.bias': '16x16/Conv1/bias',  # [512]
-    'layer7.conv.weight': '32x32/Conv0/weight',  # [512, 512, 3, 3]
-    'layer7.wscale.bias': '32x32/Conv0/bias',  # [512]
-    'layer8.conv.weight': '32x32/Conv1/weight',  # [512, 512, 3, 3]
-    'layer8.wscale.bias': '32x32/Conv1/bias',  # [512]
-    'layer9.conv.weight': '64x64/Conv0/weight',  # [256, 512, 3, 3]
-    'layer9.wscale.bias': '64x64/Conv0/bias',  # [256]
-    'layer10.conv.weight': '64x64/Conv1/weight',  # [256, 256, 3, 3]
-    'layer10.wscale.bias': '64x64/Conv1/bias',  # [256]
-    'layer11.conv.weight': '128x128/Conv0/weight',  # [128, 256, 3, 3]
-    'layer11.wscale.bias': '128x128/Conv0/bias',  # [128]
-    'layer12.conv.weight': '128x128/Conv1/weight',  # [128, 128, 3, 3]
-    'layer12.wscale.bias': '128x128/Conv1/bias',  # [128]
-    'layer13.conv.weight': '256x256/Conv0/weight',  # [64, 128, 3, 3]
-    'layer13.wscale.bias': '256x256/Conv0/bias',  # [64]
-    'layer14.conv.weight': '256x256/Conv1/weight',  # [64, 64, 3, 3]
-    'layer14.wscale.bias': '256x256/Conv1/bias',  # [64]
-    'layer15.conv.weight': '512x512/Conv0/weight',  # [32, 64, 3, 3]
-    'layer15.wscale.bias': '512x512/Conv0/bias',  # [32]
-    'layer16.conv.weight': '512x512/Conv1/weight',  # [32, 32, 3, 3]
-    'layer16.wscale.bias': '512x512/Conv1/bias',  # [32]
-    'layer17.conv.weight': '1024x1024/Conv0/weight',  # [16, 32, 3, 3]
-    'layer17.wscale.bias': '1024x1024/Conv0/bias',  # [16]
-    'layer18.conv.weight': '1024x1024/Conv1/weight',  # [16, 16, 3, 3]
-    'layer18.wscale.bias': '1024x1024/Conv1/bias',  # [16]
-    'output_1024x1024.conv.weight': 'ToRGB_lod0/weight',  # [3, 16, 1, 1]
-    'output_1024x1024.wscale.bias': 'ToRGB_lod0/bias',  # [3]
+    'lod': 'lod',  # []
+    'layer0.conv.weight': '4x4/Dense/weight',  # [512, 512, 4, 4]
+    'layer0.wscale.bias': '4x4/Dense/bias',  # [512]
+    'layer1.conv.weight': '4x4/Conv/weight',  # [512, 512, 3, 3]
+    'layer1.wscale.bias': '4x4/Conv/bias',  # [512]
+    'layer2.conv.weight': '8x8/Conv0/weight',  # [512, 512, 3, 3]
+    'layer2.wscale.bias': '8x8/Conv0/bias',  # [512]
+    'layer3.conv.weight': '8x8/Conv1/weight',  # [512, 512, 3, 3]
+    'layer3.wscale.bias': '8x8/Conv1/bias',  # [512]
+    'layer4.conv.weight': '16x16/Conv0/weight',  # [512, 512, 3, 3]
+    'layer4.wscale.bias': '16x16/Conv0/bias',  # [512]
+    'layer5.conv.weight': '16x16/Conv1/weight',  # [512, 512, 3, 3]
+    'layer5.wscale.bias': '16x16/Conv1/bias',  # [512]
+    'layer6.conv.weight': '32x32/Conv0/weight',  # [512, 512, 3, 3]
+    'layer6.wscale.bias': '32x32/Conv0/bias',  # [512]
+    'layer7.conv.weight': '32x32/Conv1/weight',  # [512, 512, 3, 3]
+    'layer7.wscale.bias': '32x32/Conv1/bias',  # [512]
+    'layer8.conv.weight': '64x64/Conv0/weight',  # [256, 512, 3, 3]
+    'layer8.wscale.bias': '64x64/Conv0/bias',  # [256]
+    'layer9.conv.weight': '64x64/Conv1/weight',  # [256, 256, 3, 3]
+    'layer9.wscale.bias': '64x64/Conv1/bias',  # [256]
+    'layer10.conv.weight': '128x128/Conv0/weight',  # [128, 256, 3, 3]
+    'layer10.wscale.bias': '128x128/Conv0/bias',  # [128]
+    'layer11.conv.weight': '128x128/Conv1/weight',  # [128, 128, 3, 3]
+    'layer11.wscale.bias': '128x128/Conv1/bias',  # [128]
+    'layer12.conv.weight': '256x256/Conv0/weight',  # [64, 128, 3, 3]
+    'layer12.wscale.bias': '256x256/Conv0/bias',  # [64]
+    'layer13.conv.weight': '256x256/Conv1/weight',  # [64, 64, 3, 3]
+    'layer13.wscale.bias': '256x256/Conv1/bias',  # [64]
+    'layer14.conv.weight': '512x512/Conv0/weight',  # [32, 64, 3, 3]
+    'layer14.wscale.bias': '512x512/Conv0/bias',  # [32]
+    'layer15.conv.weight': '512x512/Conv1/weight',  # [32, 32, 3, 3]
+    'layer15.wscale.bias': '512x512/Conv1/bias',  # [32]
+    'layer16.conv.weight': '1024x1024/Conv0/weight',  # [16, 32, 3, 3]
+    'layer16.wscale.bias': '1024x1024/Conv0/bias',  # [16]
+    'layer17.conv.weight': '1024x1024/Conv1/weight',  # [16, 16, 3, 3]
+    'layer17.wscale.bias': '1024x1024/Conv1/bias',  # [16]
+    'output0.conv.weight': 'ToRGB_lod8/weight',  # [3, 512, 1, 1]
+    'output0.wscale.bias': 'ToRGB_lod8/bias',  # [3]
+    'output1.conv.weight': 'ToRGB_lod7/weight',  # [3, 512, 1, 1]
+    'output1.wscale.bias': 'ToRGB_lod7/bias',  # [3]
+    'output2.conv.weight': 'ToRGB_lod6/weight',  # [3, 512, 1, 1]
+    'output2.wscale.bias': 'ToRGB_lod6/bias',  # [3]
+    'output3.conv.weight': 'ToRGB_lod5/weight',  # [3, 512, 1, 1]
+    'output3.wscale.bias': 'ToRGB_lod5/bias',  # [3]
+    'output4.conv.weight': 'ToRGB_lod4/weight',  # [3, 256, 1, 1]
+    'output4.wscale.bias': 'ToRGB_lod4/bias',  # [3]
+    'output5.conv.weight': 'ToRGB_lod3/weight',  # [3, 128, 1, 1]
+    'output5.wscale.bias': 'ToRGB_lod3/bias',  # [3]
+    'output6.conv.weight': 'ToRGB_lod2/weight',  # [3, 64, 1, 1]
+    'output6.wscale.bias': 'ToRGB_lod2/bias',  # [3]
+    'output7.conv.weight': 'ToRGB_lod1/weight',  # [3, 32, 1, 1]
+    'output7.wscale.bias': 'ToRGB_lod1/bias',  # [3]
+    'output8.conv.weight': 'ToRGB_lod0/weight',  # [3, 16, 1, 1]
+    'output8.wscale.bias': 'ToRGB_lod0/bias',  # [3]
 }
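The renumbered mapping is fully regular (a reading aid, not part of the commit): conv layers now count from `layer0` at 4x4 upward, and the ToRGB heads count `lod` down as resolution doubles, so `output{i}` pairs with `ToRGB_lod{8 - i}` in a 1024-resolution model. A short sketch reproducing the head pairing under that assumption:

    import numpy as np

    resolution = 1024
    num_blocks = int(np.log2(resolution)) - 1  # 9 blocks: 4x4 up to 1024x1024
    for i in range(num_blocks):
        res = 4 * 2 ** i
        print(f"'output{i}' <-> 'ToRGB_lod{num_blocks - 1 - i}'  # {res}x{res}")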


-class PGGANGeneratorModel(nn.Sequential):
+class PGGANGeneratorModel(nn.Module):
   """Defines the generator module in ProgressiveGAN.
 
   Note that the generated images have RGB color channels with pixel values in
   range [-1, 1].
   """
 
-  def __init__(self, resolution=1024, final_tanh=False):
+  def __init__(self,
+               resolution=1024,
+               fused_scale=False,
+               output_channels=3):
     """Initializes the generator with basic settings.
 
     Args:
       resolution: The resolution of the final output image. (default: 1024)
-      final_tanh: Whether to use a `tanh` function to clamp the pixel values of
-        the output image to range [-1, 1]. (default: False)
+      fused_scale: Whether to fuse `upsample` and `conv2d` together, resulting
+        in `conv2d_transpose`. (default: False)
+      output_channels: Number of channels of the output image. (default: 3)
 
     Raises:
       ValueError: If the input `resolution` is not supported.
     """
+    super().__init__()
 
     try:
-      channels = _RESOLUTIONS_TO_CHANNELS[resolution]
+      self.channels = _RESOLUTIONS_TO_CHANNELS[resolution]
     except KeyError:
       raise ValueError(f'Invalid resolution: {resolution}!\n'
                        f'Resolutions allowed: '
                        f'{list(_RESOLUTIONS_TO_CHANNELS)}.')

-    sequence = OrderedDict()
-
-    def _add_layer(layer, name=None):
-      name = name or f'layer{len(sequence) + 1}'
-      sequence[name] = layer
-
-    _add_layer(ConvBlock(channels[0], channels[1], kernel_size=4, padding=3))
-    _add_layer(ConvBlock(channels[1], channels[1]))
-    for i in range(2, len(channels)):
-      _add_layer(ConvBlock(channels[i-1], channels[i], upsample=True))
-      _add_layer(ConvBlock(channels[i], channels[i]))
-    # Final convolutional block.
-    _add_layer(ConvBlock(in_channels=channels[-1],
-                         out_channels=3,
-                         kernel_size=1,
-                         padding=0,
-                         wscale_gain=1.0,
-                         activation_type='tanh' if final_tanh else 'linear'),
-               name=f'output_{resolution}x{resolution}')
-    super().__init__(sequence)
-    self.pth_to_tf_var_mapping = _PGGAN_PTH_VARS_TO_TF_VARS
+    assert len(self.channels) == int(np.log2(resolution))
+
+    self.resolution = resolution
+    self.fused_scale = fused_scale
+    self.output_channels = output_channels
+
+    for block_idx in range(1, len(self.channels)):
+      if block_idx == 1:
+        self.add_module(
+            f'layer{2 * block_idx - 2}',
+            ConvBlock(in_channels=self.channels[block_idx - 1],
+                      out_channels=self.channels[block_idx],
+                      kernel_size=4,
+                      padding=3))
+      else:
+        self.add_module(
+            f'layer{2 * block_idx - 2}',
+            ConvBlock(in_channels=self.channels[block_idx - 1],
+                      out_channels=self.channels[block_idx],
+                      upsample=True,
+                      fused_scale=self.fused_scale))
+      self.add_module(
+          f'layer{2 * block_idx - 1}',
+          ConvBlock(in_channels=self.channels[block_idx],
+                    out_channels=self.channels[block_idx]))
+      self.add_module(
+          f'output{block_idx - 1}',
+          ConvBlock(in_channels=self.channels[block_idx],
+                    out_channels=self.output_channels,
+                    kernel_size=1,
+                    padding=0,
+                    wscale_gain=1.0,
+                    activation_type='linear'))
+
+    self.upsample = ResolutionScalingLayer()
+    self.lod = nn.Parameter(torch.zeros(()))
+
+    self.pth_to_tf_var_mapping = {}
+    for pth_var_name, tf_var_name in _PGGAN_PTH_VARS_TO_TF_VARS.items():
+      if self.fused_scale and 'Conv0' in tf_var_name:
+        pth_var_name = pth_var_name.replace('conv.weight', 'weight')
+        tf_var_name = tf_var_name.replace('Conv0', 'Conv0_up')
+      self.pth_to_tf_var_mapping[pth_var_name] = tf_var_name

   def forward(self, x):
     if len(x.shape) != 2:
       raise ValueError(f'The input tensor should be with shape [batch_size, '
                        f'noise_dim], but {x.shape} received!')
     x = x.view(x.shape[0], x.shape[1], 1, 1)
-    return super().forward(x)
+
+    lod = self.lod.cpu().tolist()
+    for block_idx in range(1, len(self.channels)):
+      if block_idx + lod < len(self.channels):
+        x = self.__getattr__(f'layer{2 * block_idx - 2}')(x)
+        x = self.__getattr__(f'layer{2 * block_idx - 1}')(x)
+        image = self.__getattr__(f'output{block_idx - 1}')(x)
+      else:
+        image = self.upsample(image)
+    return image
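The `lod` (level-of-detail) parameter is what makes unfinished checkpoints usable: with `lod = L`, the loop above runs only the first `num_blocks - L` blocks, takes the image from that block's RGB head, and nearest-upsamples it for the remaining iterations, so the output tensor is always full resolution but carries `resolution / 2**L` worth of detail. A usage sketch (not from the repo; shapes assume the defaults above):

    import torch
    from models.pggan_generator_model import PGGANGeneratorModel

    G = PGGANGeneratorModel(resolution=1024)
    with torch.no_grad():
        G.lod.fill_(2.0)                # e.g. a snapshot grown only to 256x256
        image = G(torch.randn(1, 512))
    print(image.shape)  # torch.Size([1, 3, 1024, 1024]); content is rendered
                        # at 256x256 by `output6`, then upsampled 4x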


 class PixelNormLayer(nn.Module):
@@ -150,9 +203,7 @@ def __init__(self, scale_factor=2):
     self.scale_factor = scale_factor
 
   def forward(self, x):
-    return nn.functional.interpolate(x,
-                                     scale_factor=self.scale_factor,
-                                     mode='nearest')
+    return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')


 class WScaleLayer(nn.Module):
@@ -190,6 +241,7 @@ def __init__(self,
                dilation=1,
                add_bias=False,
                upsample=False,
+               fused_scale=False,
                wscale_gain=np.sqrt(2.0),
                activation_type='lrelu'):
"""Initializes the class with block settings.
Expand All @@ -203,6 +255,8 @@ def __init__(self,
dilation: Dilation rate for convolution operation.
add_bias: Whether to add bias onto the convolutional result.
upsample: Whether to upsample the input tensor before convolution.
fused_scale: Whether to fused `upsample` and `conv2d` together, resulting
in `conv2_transpose`.
wscale_gain: The gain factor for `wscale` layer.
wscale_lr_multiplier: The learning rate multiplier factor for `wscale`
layer.
@@ -214,19 +268,32 @@ def __init__(self,
     """
     super().__init__()
     self.pixel_norm = PixelNormLayer()
-    self.upsample = ResolutionScalingLayer() if upsample else (lambda x: x)
-    self.conv = nn.Conv2d(in_channels=in_channels,
-                          out_channels=out_channels,
-                          kernel_size=kernel_size,
-                          stride=stride,
-                          padding=padding,
-                          dilation=dilation,
-                          groups=1,
-                          bias=add_bias)
+
+    if upsample and not fused_scale:
+      self.upsample = ResolutionScalingLayer()
+    else:
+      self.upsample = lambda x: x
+
+    if upsample and fused_scale:
+      self.weight = nn.Parameter(
+          torch.randn(kernel_size, kernel_size, in_channels, out_channels))
+      fan_in = in_channels * kernel_size * kernel_size
+      self.scale = wscale_gain / np.sqrt(fan_in)
+    else:
+      self.conv = nn.Conv2d(in_channels=in_channels,
+                            out_channels=out_channels,
+                            kernel_size=kernel_size,
+                            stride=stride,
+                            padding=padding,
+                            dilation=dilation,
+                            groups=1,
+                            bias=add_bias)
 
     self.wscale = WScaleLayer(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               gain=wscale_gain)
 
     if activation_type == 'linear':
       self.activate = (lambda x: x)
     elif activation_type == 'lrelu':
@@ -240,7 +307,16 @@ def __init__(self,
   def forward(self, x):
     x = self.pixel_norm(x)
     x = self.upsample(x)
-    x = self.conv(x)
+    if hasattr(self, 'conv'):
+      x = self.conv(x)
+    else:
+      kernel = self.weight * self.scale
+      kernel = F.pad(kernel, (0, 0, 0, 0, 1, 1, 1, 1), 'constant', 0.0)
+      kernel = (kernel[1:, 1:] + kernel[:-1, 1:] +
+                kernel[1:, :-1] + kernel[:-1, :-1])
+      kernel = kernel.permute(2, 3, 0, 1)
+      x = F.conv_transpose2d(x, kernel, stride=2, padding=1)
+      x = x / self.scale
     x = self.wscale(x)
     x = self.activate(x)
     return x
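Why the fused branch works (an independent check, not from the commit): padding the 3x3 kernel to 5x5 and summing its four shifted copies builds the 4x4 stride-2 transposed-conv kernel whose output matches nearest-neighbor upsampling followed by a 3x3 convolution, provided the 3x3 kernel is first rotated 180 degrees spatially. A sketch verifying that equivalence:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    x = torch.randn(1, 8, 16, 16)  # [N, C_in, H, W]
    w = torch.randn(4, 8, 3, 3)    # conv kernel, [C_out, C_in, kH, kW]

    # Path A: nearest-neighbor upsample, then 3x3 conv with padding 1.
    ref = F.conv2d(F.interpolate(x, scale_factor=2, mode='nearest'), w, padding=1)

    # Path B: the ConvBlock kernel trick, applied to the spatially flipped
    # kernel arranged in [kH, kW, C_in, C_out] layout.
    k = w.flip(2, 3).permute(2, 3, 1, 0)                   # [3, 3, C_in, C_out]
    k = F.pad(k, (0, 0, 0, 0, 1, 1, 1, 1))                 # [5, 5, C_in, C_out]
    k = k[1:, 1:] + k[:-1, 1:] + k[1:, :-1] + k[:-1, :-1]  # [4, 4, C_in, C_out]
    k = k.permute(2, 3, 0, 1)                              # [C_in, C_out, 4, 4]
    out = F.conv_transpose2d(x, k, stride=2, padding=1)

    print(torch.allclose(ref, out, atol=1e-4))  # expected: True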
models/stylegan_generator.py (9 changes: 8 additions & 1 deletion)
@@ -45,9 +45,12 @@ def __init__(self, model_name, logger=None):
 
   def build(self):
     self.check_attr('w_space_dim')
+    self.check_attr('fused_scale')
     self.model = StyleGANGeneratorModel(
         resolution=self.resolution,
         w_space_dim=self.w_space_dim,
+        fused_scale=self.fused_scale,
+        output_channels=self.output_channels,
         truncation_psi=self.truncation_psi,
         truncation_layers=self.truncation_layers,
         randomize_noise=self.randomize_noise)
@@ -59,6 +62,8 @@ def load(self):
       state_dict[var_name] = self.model.state_dict()[var_name]
     self.model.load_state_dict(state_dict)
     self.logger.info(f'Successfully loaded!')
+    self.lod = self.model.synthesis.lod.to(self.cpu_device).tolist()
+    self.logger.info(f' `lod` of the loaded model is {self.lod}.')
 
   def convert_tf_model(self, test_num=10):
     import sys
@@ -82,9 +87,11 @@ def convert_tf_model(self, test_num=10):
     state_dict = self.model.state_dict()
     for pth_var_name, tf_var_name in self.model.pth_to_tf_var_mapping.items():
       if tf_var_name not in tf_vars:
+        self.logger.debug(f'Variable `{tf_var_name}` does not exist in '
+                          f'tensorflow model.')
         continue
       self.logger.debug(f' Converting `{tf_var_name}` to `{pth_var_name}`.')
-      var = torch.from_numpy(tf_vars[tf_var_name])
+      var = torch.from_numpy(np.array(tf_vars[tf_var_name]))
       if 'weight' in pth_var_name:
         if 'dense' in pth_var_name:
           var = var.permute(1, 0)
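Taken together, the wrappers can now restore a snapshot saved while the network was still growing and report its training progress. A hypothetical usage sketch (the model name and build/load lifecycle are assumed, not shown in this commit):

    from models.pggan_generator import PGGANGenerator

    generator = PGGANGenerator('pggan_celebahq')  # assumed model name
    generator.build()   # assumed to be required before load()
    generator.load()    # logs: `lod` of the loaded model is 2.0
    # lod = 2.0 means the checkpoint was grown to 1024 / 2**2 = 256 resolution;
    # forward passes still return 1024x1024 images (see forward() above).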
models/stylegan_generator_model.py (548 changes: 253 additions & 295 deletions; diff not rendered on this page)
