SMaRt: Improving GANs with Score Matching Regularity (ICML 2024)
Mengfei Xia, Yujun Shen, Ceyuan Yang, Ran Yi, Wenping Wang, Yong-Jin Liu
[Paper]
Abstract: Generative adversarial networks (GANs) usually struggle in learning from highly diverse data, whose underlying manifold is complex. In this work, we revisit the mathematical foundations of GANs, and theoretically reveal that the native adversarial loss for GAN training is insufficient to fix the problem of subsets with positive Lebesgue measure of the generated data manifold lying out of the real data manifold. Instead, we find that score matching serves as a promising solution to this issue thanks to its capability of persistently pushing the generated data points towards the real data manifold. We thereby propose to improve the optimization of GANs with score matching regularity (SMaRt). Regarding the empirical evidences, we first design a toy example to show that training GANs by the aid of a ground-truth score function can help reproduce the real data distribution more accurately, and then confirm that our approach can consistently boost the synthesis performance of various state-of-the-art GANs on real-world datasets with pre-trained diffusion models acting as the approximate score function. For instance, when training Aurora on the ImageNet 64 × 64 dataset, we manage to improve FID from 8.87 to 7.11, on par with the performance of one-step consistency model.
This repository is developed based on Hammer, where you can find more detailed instructions on installation. Here, we summarize the necessary steps to facilitate reproduction.
-
Environment: CUDA version == 11.1.
-
Install package requirements with
conda
:conda create -n smart python=3.8 # create virtual environment with Python 3.8 conda activate smart pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 --extra-index-url https://download.pytorch.org/whl/cu111 pip install -r requirements/minimal.txt -f https://download.pytorch.org/whl/cu111/torch_stable.html pip install protobuf==3.20 pip install absl-py einops ftfy==6.1.1
First, please download the pre-trained model following the links below.
To synthesize images, you can use the following command
# Synthesize using Aurora on ImageNet64 with SMaRt.
python run_synthesize.py smart_aurora_imagenet64.pth --syn_num 50000
# Synthesize using StyleGAN2 on LSUN Bedroom with SMaRt.
python run_synthesize.py smart_stylegan2_lsun_bedroom256.pth --syn_num 50000
Implementing SMaRt is based on the objective functions below:
$$\mathcal L_{score}=\mathbb E_{\mathbf z,\boldsymbol\epsilon,t}[|\boldsymbol\epsilon_\theta(\alpha_tg_\phi(\mathbf z)+\sigma_t\boldsymbol\epsilon,t)-\boldsymbol\epsilon|2^2]\quad\text{unconditional GAN, Equation (10),}$$ $$\mathcal L{score}=\mathbb E_{\mathbf z,\boldsymbol\epsilon,t,c}[|\boldsymbol\epsilon_\theta(\alpha_tg_\phi(\mathbf z,c)+\sigma_t\boldsymbol\epsilon,c,t)-\boldsymbol\epsilon|_2^2]\quad\text{conditional GAN, Equation (15).}$$
Therefore, it is necessary to implement the score matching objective using pre-trained DPMs. We provide the pseudo-code below for conditional generation:
def forward_step(cur_iter, freq, G, DPM, z, c, t, lambda_score=0.1):
"""Define the forward process of one training step.
Args:
cur_iter: Current iteration, determining whether to use SMaRt.
freq: Frequency to involve SMaRt.
G: The generator module to learn.
DPM: The pre-trained DPM, fixed while training.
z: Random noise inputted to G.
c: Condition inputted to G.
t: Preset timestep for score matching.
lambda_score: Loss weight.
"""
# Directly skip.
if cur_iter % freq != 0:
return None
# Generate fake images with G.
image = G(z, c)
# Forward diffusing process.
noise = torch.randn_like(image)
x_t = alpha_t * image + sigma_t * noise
# Calculate the score matching regularity.
noise_pred = DPM(x_t, c, t)
loss = (noise - noise_pred).square().mean()
return loss * lambda_score
According to Table 6 in Appendix, we provide the empirical value of hyper-parameters used in our experiments.
Dataset | CIFAR10 | ImageNet 64 | ImageNet 128 | LSUN Bedroom |
---|---|---|---|---|
Setting | Conditional | Conditional | Conditional | Unconditional |
Dataset Scale | 50K Images | 1.3M Images | 1.3M Images | 3M Images |
Frequency |
If you find the code useful for your research, please consider citing
@inproceedings{xia2024smart,
title={SMaRt: Improving GANs with Score Matching Regularity},
author={Xia, Mengfei and Shen, Yujun and Yang, Ceyuan and Yi, Ran and Wang, Wenping and Liu, Yong-Jin},
booktitle={International Conference on Machine Learning (ICML)},
year={2024},
}
The project is under MIT License, and is for research purpose ONLY.
We highly appreciate StyleGAN2, Aurora, ADM, EDM, and Hammer for their contributions to the community.