
DCTdiff

Official PyTorch implementation of DCTdiff: Intriguing Properties of Image Generative Modeling in the DCT Space

This repository is based on UViT. The UViT code is kept unchanged in the UViT branch, where you can find our baseline implementation.

We will also release a DiT-based version of DCTdiff in the near future.


Installation

The installation works for both DCTdiff and UViT:

conda create -n dctdiff python=3.9
conda activate dctdiff

pip install matplotlib
pip install accelerate==0.33.0   # automatically installs PyTorch 2.4
pip install absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1
pip install opencv-python
pip install scipy

# xformers is optional, but it greatly speeds up the attention computation.
pip install -U xformers
pip install torchvision==0.19.0
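
A quick way to verify the environment (a minimal sketch we add for convenience; not part of the repo):

# environment sanity check (illustrative)
import torch
import accelerate
import transformers

print(f"torch {torch.__version__}, CUDA available: {torch.cuda.is_available()}")
print(f"accelerate {accelerate.__version__}, transformers {transformers.__version__}")

try:
    import xformers  # optional: memory-efficient attention
    print(f"xformers {xformers.__version__}")
except ImportError:
    print("xformers not installed (optional)")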

Pretrained Models

Model                          FID-50K (50 NFE, DPM-Solver)   Training steps   Batch size
CIFAR10 (DCTdiff-M/4)          5.54                           175k             256
CelebA 64x64 (DCTdiff-S/4)     1.85                           225k             256
ImageNet 64x64 (DCTdiff-S/4)   9.93                           475k             1024
FFHQ 128x128 (DCTdiff-S/8)     6.46                           300k             256
FFHQ 256x256 (DCTdiff-M/8)     5.62                           200k             256
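
The checkpoints referenced by --nnet_path in the evaluation commands below appear to be plain PyTorch state dicts (an assumption on our part); a quick way to inspect one:

# inspect a downloaded checkpoint (assumes a plain PyTorch state dict)
import torch

state_dict = torch.load("cifar10_DCTdiff_mid_2by2.pth", map_location="cpu")
n_params = sum(p.numel() for p in state_dict.values() if torch.is_tensor(p))
print(f"{len(state_dict)} entries, {n_params / 1e6:.1f}M parameters")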

Preparation Before Training and Evaluation

Data

Each dataset is organized as a plain image folder. Both jpg and png images are supported.

Note that if you train on png images, your generated images must also be saved as png before computing FID. In our experiments, we use jpg images, and the fid_stats are computed from the jpg image folder.

  • CIFAR-10 32x32: we provide a script (tools/download_cifar10.py) to download CIFAR-10.
  • CelebA 64x64: download the dataset, then center-crop to 64x64 using the script tools/dataset_celeba64.py (a sketch of the crop step follows this list).
  • ImageNet 64x64: download the dataset and use the train folder for training.
  • FFHQ 128x128: download the dataset.
  • FFHQ 256x256: download the dataset.
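
For reference, the CelebA center-crop step could look roughly like the sketch below (illustrative only; the exact crop region and the paths are assumptions, so use tools/dataset_celeba64.py for the actual preprocessing):

# illustrative center crop to 64x64 (use tools/dataset_celeba64.py for the real preprocessing)
import os
from PIL import Image  # Pillow is installed as a torchvision dependency

src_dir, dst_dir = "path/to/img_align_celeba", "path/to/celeba64"  # assumed paths
os.makedirs(dst_dir, exist_ok=True)

for name in os.listdir(src_dir):
    img = Image.open(os.path.join(src_dir, name))
    w, h = img.size
    side = min(w, h)  # largest centered square
    left, top = (w - side) // 2, (h - side) // 2
    img = img.crop((left, top, left + side, top + side)).resize((64, 64), Image.LANCZOS)
    img.save(os.path.join(dst_dir, name))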

Reference statistics for FID

Download the fid_stats directory from this link (it contains the reference statistics for FID).

Put the downloaded fid_stats into the directory assets/fid_stats; the fid_stats path is set in the script datasets.py and is used for FID computation during both training and inference.

Using pytorch-fid, you can also:

  • generate your own fid_stats for a given dataset
  • compute FID whenever you need it.
python -m pytorch_fid --save-stats path/to/dataset_folder path/to/fid_stats  # generate fid_stats
python -m pytorch_fid path/to/dataset1_folder path/to/dataset2_folder  # FID calculation
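
If you prefer calling pytorch-fid from Python instead of the command line, something like the following works (a sketch; the function signature has varied across pytorch-fid releases, so check your installed version):

# compute FID via the pytorch-fid Python API (sketch; API may differ across versions)
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = "cuda" if torch.cuda.is_available() else "cpu"
fid = calculate_fid_given_paths(
    ["path/to/dataset1_folder", "path/to/dataset2_folder"],  # each entry: a folder or a precomputed .npz stats file
    batch_size=50,
    device=device,
    dims=2048,  # InceptionV3 pool3 features, the standard choice
)
print(f"FID: {fid:.2f}")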

Parameters (eta, entropy, m*)

We provide all DCT-related parameters in the config files under configs.

If you need to train on other datasets, please use the script DCT_datasets_statis.py to compute eta and entropy, or to determine m*.
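
For intuition, the kind of block-wise DCT statistics involved can be computed along the following lines (an illustrative sketch only; the exact definitions of eta, entropy, and m* follow the paper and DCT_datasets_statis.py, and the block size and aggregation here are our assumptions):

# illustrative block-wise DCT statistics (see DCT_datasets_statis.py for the actual script)
import numpy as np
from scipy.fftpack import dct

def dct2(block):
    # 2D type-II DCT with orthonormal scaling
    return dct(dct(block.T, norm='ortho').T, norm='ortho')

def block_dct_energy(image, block_size=8):
    # average squared DCT coefficient over all blocks of a grayscale image
    h, w = image.shape
    h, w = h - h % block_size, w - w % block_size  # drop incomplete border blocks
    energy = np.zeros((block_size, block_size))
    n_blocks = 0
    for i in range(0, h, block_size):
        for j in range(0, w, block_size):
            energy += dct2(image[i:i + block_size, j:j + block_size]) ** 2
            n_blocks += 1
    return energy / n_blocks

img = np.random.rand(64, 64)  # replace with a real grayscale image in [0, 1]
energy = block_dct_energy(img)
p = energy / energy.sum()  # normalize energies to a distribution over frequencies
entropy = -(p * np.log2(p + 1e-12)).sum()  # spectral entropy; low for natural images
print(f"entropy of DCT energy distribution: {entropy:.2f} bits")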

Training

We use the Hugging Face accelerate library for training with distributed data parallel and mixed precision.

We provide all commands to reproduce the DCTdiff training runs in the paper (4x A100 GPUs are used for all experiments).

Feel free to change the sampler, NFE, and num_samples in the config file; a sketch of the relevant config section follows.
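
Since the configs follow the UViT style (ml_collections), the sampling section you would edit looks roughly like this (field names are assumptions based on the UViT codebase; check the actual files under configs):

# rough shape of a config file's sampling section (field names assumed from UViT-style configs)
import ml_collections

def d(**kwargs):
    return ml_collections.ConfigDict(initial_dictionary=kwargs)

config = ml_collections.ConfigDict()
config.sample = d(
    sample_steps=50,          # NFE
    n_samples=50000,          # number of samples for FID-50K
    algorithm='dpm_solver',   # sampler
)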

# CIFAR10 32x32 (DCTdiff-M/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/cifar10_uvit_mid_2by2.py --workdir YOUR_DIR

# CelebA 64x64 (DCTdiff-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/celeba64_uvit_small_2by2.py --workdir YOUR_DIR

# ImageNet 64x64 (DCTdiff-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/imgnet64_uvit_small_2by2.py --workdir YOUR_DIR

# FFHQ 128x128 (DCTdiff-S/8)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/ffhq128_uvit_small_4by4.py --workdir YOUR_DIR

# FFHQ 256x256 (DCTdiff-M/8)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 train.py --config=configs/ffhq256_uvit_mid_4by4.py --workdir YOUR_DIR

Evaluation (Compute FID)

We use the Hugging Face accelerate library for efficient inference with mixed precision and multiple GPUs.

The following commands reproduce the FID results in the paper:

# CIFAR10 32x32 (DCTdiff-M/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/cifar10_uvit_mid_2by2.py --nnet_path=cifar10_DCTdiff_mid_2by2.pth --output_path YOUR_DIR

# CelebA 64x64 (DCTdiff-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/celeba64_uvit_small_2by2.py --nnet_path=celeba64_DCTdiff_small_2by2.pth --output_path YOUR_DIR

# ImageNet 64x64 (DCTdiff-S/4)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/imgnet64_uvit_small_2by2.py --nnet_path=imgnet64_DCTdiff_small_2by2.pth --output_path YOUR_DIR

# FFHQ 128x128 (DCTdiff-S/8)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/ffhq128_uvit_small_4by4.py --nnet_path=ffhq128_DCTdiff_small_4by4.pth --output_path YOUR_DIR

# FFHQ 256x256 (DCTdiff-M/8)
accelerate launch --multi_gpu --num_processes 4 --mixed_precision fp16 eval.py --config=configs/ffhq256_uvit_mid_4by4.py --nnet_path=ffhq256_DCTdiff_mid_4by4.pth --output_path YOUR_DIR

References

If you find the code useful for your research, please consider citing:

@article{ning2024dctdiff,
  title={DCTdiff: Intriguing Properties of Image Generative Modeling in the DCT Space},
  author={Ning, Mang and Li, Mingxiao and Su, Jianlin and Jia, Haozhe and Liu, Lanmiao and Bene{\v{s}}, Martin and Salah, Albert Ali and Ertugrul, Itir Onal},
  journal={arXiv preprint arXiv:2412.15032},
  year={2024}
}
