init pub for aaai24
XiYe20 committed Feb 28, 2024
1 parent 98a8537 commit 4040a88
Showing 40 changed files with 5,823 additions and 0 deletions.
150 changes: 150 additions & 0 deletions README.md
@@ -0,0 +1,150 @@


# STDiff: Spatio-temporal diffusion for continuous stochastic video prediction
[arXiv](https://arxiv.org/abs/2312.06486) | [code](https://github.com/XiYe20/STDiffProject)


<h3 align="center"> <img src="./documentations/STDiff_BAIR_15.gif" alt="STDiff_BAIR_15"> </h3>

## Overview
<p align="center">
<img src="./documentations/NN_arch.png" alt="STDiff Architecture" width="100%">
</p>

## Installation
1. Install the custom diffusers library
```bash
git clone https://github.com/XiYe20/CustomDiffusers.git
cd CustomDiffusers
pip install -e .
```
2. Install the requirements of STDiff
```bash
pip install -r requirements.txt
```
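
To confirm that the active environment matches the pinned versions, a small check along the following lines may help. This is a sketch and not part of the repository: `parse_pins` and `check_pins` are hypothetical helper names, and the parser assumes that `requirements.txt` contains only plain `name==version` lines.

```python
from importlib import metadata


def parse_pins(text):
    """Parse 'name==version' lines into a dict, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "==" not in line:
            continue
        name, version = line.split("==", 1)
        pins[name.strip()] = version.strip()
    return pins


def check_pins(pins):
    """Return {name: (wanted, installed_or_None)} for every mismatched package."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[name] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    with open("requirements.txt") as f:
        report = check_pins(parse_pins(f.read()))
    for name, (wanted, installed) in report.items():
        print(f"{name}: wanted {wanted}, found {installed}")
```

An empty report means every pinned package is installed at the requested version.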

## Datasets

Processed KTH dataset: https://drive.google.com/file/d/1RbJyGrYdIp4ROy8r0M-lLAbAMxTRQ-sd/view?usp=sharing \
SM-MNIST: https://drive.google.com/file/d/1eSpXRojBjvE4WoIgeplUznFyRyI3X64w/view?usp=drive_link

For the other datasets, please download them from their official websites. The expected folder structure for each dataset is shown below.

#### BAIR
Please download the original BAIR dataset and utilize the "/utils/read_BAIR_tfrecords.py" script to convert it into frames as follows:

/BAIR \
&nbsp;&nbsp;&nbsp;&nbsp; test/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; example_0/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0000.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0001.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; example_1/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0000.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0001.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; example_... \
&nbsp;&nbsp;&nbsp;&nbsp; train/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; example_0/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0000.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0001.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; example_...
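
As a quick sanity check of the converted frames, the layout above can be walked with a few lines of Python. This is a minimal sketch, not the project's data loader: `list_examples` is a hypothetical helper, and it only assumes the `split/example_x/*.png` structure shown above.

```python
from pathlib import Path


def list_examples(root, split):
    """Return {example_name: [sorted frame paths]} for one split ('train' or 'test')."""
    split_dir = Path(root) / split
    examples = {}
    for example_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        # Zero-padded names such as 0000.png sort correctly as plain strings.
        frames = sorted(example_dir.glob("*.png"))
        if frames:
            examples[example_dir.name] = frames
    return examples


if __name__ == "__main__":
    for name, frames in list_examples("./BAIR", "test").items():
        print(name, len(frames))
```

Example folders are returned in lexicographic order, so `example_10` sorts before `example_2`; this does not matter for a frame count but is worth keeping in mind if the order is used.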

#### Cityscapes
Please download "leftImg8bit_sequence_trainvaltest.zip" from the official website. Center-crop and resize all frames to 128×128, then save them as follows:

/Cityscapes \
&nbsp;&nbsp;&nbsp;&nbsp; test/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; berlin/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; berlin_000000_000000_leftImg8bit.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; berlin_000000_000001_leftImg8bit.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; bielefeld/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; bielefeld_000000_000302_leftImg8bit.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; bielefeld_000000_000303_leftImg8bit.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp; train/\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; aachen/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; .... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; bochum/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; .... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp; val/\
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ....
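
The center-crop and resize step could be sketched as follows. This is an illustration under stated assumptions rather than the preprocessing actually used for the paper: `center_crop_box` and `crop_and_resize` are hypothetical helpers, the crop takes the largest centered square, and bicubic interpolation is an arbitrary choice. Pillow is already pinned in `requirements.txt`.

```python
def center_crop_box(width, height):
    """Return (left, top, right, bottom) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return (left, top, left + side, top + side)


def crop_and_resize(src_path, dst_path, size=128):
    """Center-crop one frame to a square and resize it to size x size."""
    from PIL import Image  # imported lazily so the box helper has no dependencies

    with Image.open(src_path) as img:
        box = center_crop_box(*img.size)
        # Bicubic resampling is an assumption; the README does not name a filter.
        img.crop(box).resize((size, size), Image.BICUBIC).save(dst_path)
```

For a 2048×1024 Cityscapes frame, `center_crop_box(2048, 1024)` gives `(512, 0, 1536, 1024)`, that is, the central 1024×1024 region.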

#### KITTI
Please download the raw data (synced+rectified) from the official KITTI website. Center-crop and resize all frames to a resolution of 128×128.
Save all the frames as follows:

/KITTI \
&nbsp;&nbsp;&nbsp;&nbsp; 2011_09_26_drive_0001_sync/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0000000000.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; 0000000001.png \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp; 2011_09_26_drive_0002_sync/ \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ... \
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; ...

## Training and Evaluation
The STDiff project uses accelerate for training. The training and test configuration files for each dataset are located in stdiff/configs.

### Training
1. Check train_script.sh: set the visible GPUs and num_process, and select the correct train_config file
2. Training
```bash
. ./train_script.sh
```

### Test
1. Check test_script.sh and select the correct test_config file
2. Test
```bash
. ./test_script.sh
```
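
The configuration files in stdiff/configs contain absolute paths under `/home/travail/xiyex` (for example the dataset `dir`, `ckpt_path` and `test_results_path`), so they most likely need to be pointed at local directories before running the scripts. One possible way to do this in bulk is sketched below; `retarget_paths` is a hypothetical helper, `/data` is a placeholder prefix, and the function works on plain text so comments and layout in the YAML files are left untouched.

```python
def retarget_paths(config_text, old_prefix, new_prefix):
    """Replace old_prefix with new_prefix on every line; also return the changed lines."""
    changed = []
    out_lines = []
    for line in config_text.splitlines():
        new_line = line.replace(old_prefix, new_prefix)
        if new_line != line:
            changed.append(new_line.strip())
        out_lines.append(new_line)
    return "\n".join(out_lines) + "\n", changed


if __name__ == "__main__":
    path = "stdiff/configs/bair_test_config.yaml"
    with open(path) as f:
        text, changed = retarget_paths(f.read(), "/home/travail/xiyex", "/data")
    print("\n".join(changed))  # review the rewritten lines before saving
    with open(path, "w") as f:
        f.write(text)
```

Printing the changed lines first makes it easy to confirm that only path entries were rewritten.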

## Citation
```
@article{ye2023stdiff,
title={STDiff: Spatio-temporal Diffusion for Continuous Stochastic Video Prediction},
author={Ye, Xi and Bilodeau, Guillaume-Alexandre},
journal={arXiv preprint arXiv:2312.06486},
year={2023}
}
```




<h2 align="left"> Uncurated prediction examples of STDiff for multiple datasets. </h2>
The temporal coordinates are shown at the top left corner of the frame. <em>Frames with <span style="color:red"> Red temporal coordinates </span> denote future frames predicted by our model.</em>


<h3 align="left"> BAIR </h3>

<h3 align="center"> <img src="./documentations/STDiff_BAIR_0.gif" alt="STDiff_BAIR_0"> </h3>

<h3 align="center"> <img src="./documentations/STDiff_BAIR_15.gif" alt="STDiff_BAIR_15"> </h3>


<h3 align="left"> SMMNIST </h3>

<h3 align="center"> <img src="./documentations/STDiff_SMMNIST_7.gif" alt="STDiff_SMMNIST_7"> </h3>

<h3 align="center"> <img src="./documentations/STDiff_SMMNIST_10.gif" alt="STDiff_SMMNIST_10"> </h3>


<h3 align="left"> KITTI </h3>

<h3 align="center"> <img src="./documentations/STDiff_KITTI_0.gif" alt="STDiff_KITTI_0"> </h3>

<h3 align="center"> <img src="./documentations/STDiff_KITTI_22.gif" alt="STDiff_KITTI_22"> </h3>


<h3 align="left"> Cityscapes </h3>

<h3 align="center"> <img src="./documentations/STDiff_City_110.gif" alt="STDiff_City_110"> </h3>

<h3 align="center"> <img src="./documentations/STDiff_City_120.gif" alt="STDiff_City_120"> </h3>

Binary file added documentations/NN_arch.png
Binary file added documentations/STDiff_BAIR_0.gif
Binary file added documentations/STDiff_BAIR_15.gif
Binary file added documentations/STDiff_City_110.gif
Binary file added documentations/STDiff_City_120.gif
Binary file added documentations/STDiff_KITTI_0.gif
Binary file added documentations/STDiff_KITTI_22.gif
Binary file added documentations/STDiff_SMMNIST_10.gif
Binary file added documentations/STDiff_SMMNIST_7.gif
17 changes: 17 additions & 0 deletions requirements.txt
@@ -0,0 +1,17 @@
accelerate==0.17.1
einops==0.7.0
hydra-core==1.3.2
lpips==0.1.4
numpy==1.23.5
omegaconf==2.3.0
opencv_python_headless==4.7.0.72
packaging==23.2
Pillow==10.2.0
pytorch_lightning==1.6.3
scipy==1.10.1
scikit-image==0.21.0
torch==2.0.0
torchdiffeq==0.2.3
torchsde==0.2.5
torchvision==0.15.1
tqdm==4.65.0
11 changes: 11 additions & 0 deletions stdiff/configs/accelerate_config.yaml
@@ -0,0 +1,11 @@
compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: MULTI_GPU
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 1
use_cpu: false
40 changes: 40 additions & 0 deletions stdiff/configs/bair_test_config.yaml
@@ -0,0 +1,40 @@
Dataset:
name: 'BAIR'
dir: '/home/travail/xiyex/BAIR'
phase: 'deploy'
dev_set_size: null
batch_size: 64
num_workers: 16
num_channels: 3
image_size: 64
num_observed_frames: 2
num_predict_frames: 10
test_num_observed_frames: 2
test_num_predict_frames: 28
rand_Tp: null
rand_predict: False
half_fps: False

STDiff:
Diffusion:
unet_config:
sample_size: 64
DiffNet:
MotionEncoder:
image_size: 64

TestCfg:
ckpt_path: "/home/travail/xiyex/STDiff_ckpts/bair_sde_64"
test_results_path: "/home/travail/xiyex/STDiff_ckpts/bair_sde_64/test_ddpm100"
scheduler:
name: 'DDPM' #'DPMMS' or 'DDPM'
sample_steps: 100

fps: 1
metrics: ['PSNR', 'SSIM', 'LPIPS', 'InterLPIPS']
random_predict:
first_pred_sample_num: 10
first_pred_parralle_bs: 4
sample_num: 10
fix_init_noise: False

96 changes: 96 additions & 0 deletions stdiff/configs/bair_train_config.yaml
@@ -0,0 +1,96 @@
Env:
output_dir: '/home/travail/xiyex/STDiff_ckpts/bair_sde_64'
logger: 'tensorboard'
resume_ckpt: null
stdiff_init_ckpt: null


Dataset:
name: 'BAIR'
dir: './BAIR'
phase: 'deploy'
dev_set_size: null
batch_size: 6
num_workers: 32
num_channels: 3
image_size: 64
num_observed_frames: 2
num_predict_frames: 10
test_num_observed_frames: 2
test_num_predict_frames: 10
rand_Tp: 6
rand_predict: True
half_fps: False

STDiff:
Diffusion:
prediction_type: 'epsilon' #'epsilon' or 'sample'
ddpm_num_steps: 1000
ddpm_num_inference_steps: 300
ddpm_beta_schedule: 'linear'

unet_config:
sample_size: 64
in_channels: 6
out_channels: 3
m_channels: 256
layers_per_block: 2
#config for resolution 128
#block_out_channels: [256, 256, 512, 768, 1024]
#down_block_types: ["DownBlock2D","DownBlock2D","DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"]
#up_block_types: ["AttnUpBlock2D", "AttnUpBlock2D","UpBlock2D", "UpBlock2D", "UpBlock2D"]
#attention_head_dim: [null, null, null, 192, 256]

#config for resolution 64
block_out_channels: [128, 256, 256, 512, 512]
down_block_types: ["DownBlock2D","AttnDownBlock2D","AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"]
up_block_types: ["AttnUpBlock2D", "AttnUpBlock2D","AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"]
attention_head_dim: [null, 128, 128, 128, 128]

DiffNet:
autoregressive: True
super_res_training: False
MotionEncoder:
learn_diff_image: True
image_size: 64
in_channels: 3
model_channels: 64
n_downs: 2
DiffUnet:
n_layers: 2
nonlinear: 'tanh'
Int:
sde: True
method: 'euler_heun'
sde_options:
noise_type: 'diagonal'
sde_type: "stratonovich" #"Stratonovich"
dt: 0.1
rtol: 1e-3
atol: 1e-3
adaptive: False
ode_options:
step_size: 0.1
norm: null

Training:
use_ema: True
ema_inv_gamma: 1.0
ema_power: 0.75
ema_max_decay: 0.9999

learning_rate: 1e-4
lr_scheduler: 'cosine_with_restarts'
lr_warmup_steps: 500
num_cycles: 2
adam_betas: [0.95, 0.999]
adam_weight_decay: 1e-6
adam_epsilon: 1e-8

epochs: 800
save_images_epochs: 4
save_model_epochs: 2
checkpointing_steps: 10 #number of steps to save a resuming checkpoint

gradient_accumulation_steps: 1 #4 for 128 resolution, 4 GPU training
mixed_precision: "no" #["no", "fp16", "bf16"],
40 changes: 40 additions & 0 deletions stdiff/configs/cityscapes_test_config.yaml
@@ -0,0 +1,40 @@
Dataset:
name: 'CityScapes'
dir: '/home/travail/xiyex/CityScapes'
phase: 'deploy'
dev_set_size: null
batch_size: 64
num_workers: 16
num_channels: 3
image_size: 128
num_observed_frames: 2
num_predict_frames: 10
test_num_observed_frames: 2
test_num_predict_frames: 28
rand_Tp: null
rand_predict: False
half_fps: False

STDiff:
Diffusion:
unet_config:
sample_size: 128
DiffNet:
MotionEncoder:
image_size: 128

TestCfg:
ckpt_path: "/home/travail/xiyex/STDiff_ckpts/city_sde_128"
test_results_path: "/home/travail/xiyex/STDiff_ckpts/city_sde_128/test_ddpm100"
scheduler:
name: 'DDPM' #'DPMMS' or 'DDPM'
sample_steps: 100

fps: 1
metrics: ['PSNR', 'SSIM', 'LPIPS', 'InterLPIPS']
random_predict:
first_pred_sample_num: 10
first_pred_parralle_bs: 4
sample_num: 10
fix_init_noise: False
