Showing 40 changed files with 5,823 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
|
||
|
||
# STDiff: Spatio-temporal diffusion for continuous stochastic video prediction | ||
[arXiv](https://arxiv.org/abs/2312.06486) | [code](https://github.com/XiYe20/STDiffProject) | ||
|
||
|
||
<h3 align="center"> <img src="./documentations/STDiff_BAIR_15.gif" alt="STDiff_BAIR_15"> </h3> | ||
|
||
## Overview | ||
<p align="center"> | ||
<img src="./documentations/NN_arch.png" alt="STDiff Architecture" width="100%"> | ||
</p> | ||
|
||
## Installation | ||
1. Install the custom diffusers library | ||
```bash | ||
git clone https://github.com/XiYe20/CustomDiffusers.git | ||
cd CustomDiffusers | ||
pip install -e . | ||
``` | ||
2. Install the requirements of STDiff | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Datasets | ||
|
||
Processed KTH dataset: https://drive.google.com/file/d/1RbJyGrYdIp4ROy8r0M-lLAbAMxTRQ-sd/view?usp=sharing \ | ||
SM-MNIST: https://drive.google.com/file/d/1eSpXRojBjvE4WoIgeplUznFyRyI3X64w/view?usp=drive_link | ||
|
||
For the other datasets, please download them from their official websites. The expected folder structure for each dataset is shown below. | ||
|
||
#### BAIR | ||
Please download the original BAIR dataset and utilize the "/utils/read_BAIR_tfrecords.py" script to convert it into frames as follows: | ||
|
||
/BAIR \ | ||
test/ \ | ||
example_0/ \ | ||
0000.png \ | ||
0001.png \ | ||
... \ | ||
example_1/ \ | ||
0000.png \ | ||
0001.png \ | ||
... \ | ||
example_... \ | ||
train/ \ | ||
example_0/ \ | ||
0000.png \ | ||
0001.png \ | ||
... \ | ||
example_... | ||
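
As a quick sanity check after the conversion, the layout above can be verified with a short script. This is a minimal sketch, not part of the repository: the function names `summarize_split` and `check_bair_root` are hypothetical, and it only assumes the `train/` and `test/` folders with `example_*` subfolders of PNG frames shown above.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Map each example_* folder name to the number of PNG frames it holds."""
    split_dir = Path(split_dir)
    summary = {}
    for example_dir in sorted(split_dir.glob("example_*")):
        if example_dir.is_dir():
            summary[example_dir.name] = len(list(example_dir.glob("*.png")))
    return summary


def check_bair_root(root):
    """Return per-split summaries; raise if a split folder is missing."""
    root = Path(root)
    report = {}
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        report[split] = summarize_split(split_dir)
    return report
```

Calling `check_bair_root("/path/to/BAIR")` returns a dictionary such as `{"train": {"example_0": 30, ...}, "test": {...}}`, which makes empty or truncated example folders easy to spot. The frame count in that illustration is a placeholder, not a value taken from the dataset.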
|
||
#### Cityscapes | ||
Please download "leftImg8bit_sequence_trainvaltest.zip" from the official website. Center crop all the frames and resize them to 128x128. Save all the frames as follows: | ||
|
||
/Cityscapes \ | ||
test/ \ | ||
berlin/ \ | ||
berlin_000000_000000_leftImg8bit.png \ | ||
berlin_000000_000001_leftImg8bit.png \ | ||
... \ | ||
bielefeld/ \ | ||
bielefeld_000000_000302_leftImg8bit.png \ | ||
bielefeld_000000_000302_leftImg8bit.png \ | ||
... \ | ||
... \ | ||
train/\ | ||
aachen/ \ | ||
.... \ | ||
bochum/ \ | ||
.... \ | ||
... \ | ||
val/\ | ||
.... | ||
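
The center crop and resize step is not scripted in this README, so the following is only one possible way to do it. It assumes the crop is the largest centered square and that bicubic resampling is acceptable; both choices, and the names `center_crop_box` and `preprocess_frame`, are assumptions rather than the authors' preprocessing code. Pillow is already pinned in requirements.txt.

```python
def center_crop_box(width, height):
    """Return (left, upper, right, lower) of the largest centered square."""
    side = min(width, height)
    left = (width - side) // 2
    upper = (height - side) // 2
    return (left, upper, left + side, upper + side)


def preprocess_frame(src_path, dst_path, size=128):
    """Center crop one frame to a square, resize it to size x size, and save it."""
    from PIL import Image  # imported lazily so the box helper has no dependency

    with Image.open(src_path) as img:
        box = center_crop_box(*img.size)
        resized = img.crop(box).resize((size, size), Image.Resampling.BICUBIC)
        resized.save(dst_path)
```

For a 2048x1024 frame (the native Cityscapes resolution), `center_crop_box(2048, 1024)` selects the middle 1024x1024 region before the resize to 128x128.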
|
||
#### KITTI | ||
Please download the raw data (synced+rectified) from the official KITTI website. Center crop all the frames and resize them to 128x128. | ||
Save all the frames as follows: | ||
|
||
/KITTI \ | ||
2011_09_26_drive_0001_sync/ \ | ||
0000000000.png \ | ||
0000000001.png \ | ||
... \ | ||
2011_09_26_drive_0002_sync/ \ | ||
... \ | ||
... | ||
|
||
## Training and Evaluation | ||
The STDiff project uses accelerate for training. The training configuration files and test configuration files for different datasets are placed inside stdiff/configs. | ||
|
||
### Training | ||
1. Check train_script.sh: set the visible GPUs and num_process, and select the correct train_config file | ||
2. Training | ||
```bash | ||
. ./train_script.sh | ||
``` | ||
|
||
### Test | ||
1. Check test_script.sh and select the correct test_config file | ||
2. Test | ||
```bash | ||
. ./test_script.sh | ||
``` | ||
|
||
## Citation | ||
``` | ||
@article{ye2023stdiff, | ||
title={STDiff: Spatio-temporal Diffusion for Continuous Stochastic Video Prediction}, | ||
author={Ye, Xi and Bilodeau, Guillaume-Alexandre}, | ||
journal={arXiv preprint arXiv:2312.06486}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
|
||
|
||
|
||
<h2 align="left"> Uncurated prediction examples of STDiff for multiple datasets. </h2> | ||
The temporal coordinates are shown at the top left corner of each frame. <em>Frames with <span style="color:red"> red temporal coordinates </span> denote future frames predicted by our model.</em> | ||
|
||
|
||
<h3 align="left"> BAIR </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_BAIR_0.gif" alt="STDiff_BAIR_0"> </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_BAIR_15.gif" alt="STDiff_BAIR_15"> </h3> | ||
|
||
|
||
<h3 align="left"> SMMNIST </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_SMMNIST_7.gif" alt="STDiff_SMMNIST_7"> </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_SMMNIST_10.gif" alt="STDiff_SMMNIST_10"> </h3> | ||
|
||
|
||
<h3 align="left"> KITTI </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_KITTI_0.gif" alt="STDiff_KITTI_0"> </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_KITTI_22.gif" alt="STDiff_KITTI_22"> </h3> | ||
|
||
|
||
<h3 align="left"> Cityscapes </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_City_110.gif" alt="STDiff_City_110"> </h3> | ||
|
||
<h3 align="center"> <img src="./documentations/STDiff_City_120.gif" alt="STDiff_City_120"> </h3> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
accelerate==0.17.1 | ||
einops==0.7.0 | ||
hydra-core==1.3.2 | ||
lpips==0.1.4 | ||
numpy==1.23.5 | ||
omegaconf==2.3.0 | ||
opencv_python_headless==4.7.0.72 | ||
packaging==23.2 | ||
Pillow==10.2.0 | ||
pytorch_lightning==1.6.3 | ||
scipy==1.10.1 | ||
scikit-image==0.21.0 | ||
torch==2.0.0 | ||
torchdiffeq==0.2.3 | ||
torchsde==0.2.5 | ||
torchvision==0.15.1 | ||
tqdm==4.65.0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
compute_environment: LOCAL_MACHINE | ||
deepspeed_config: {} | ||
distributed_type: MULTI_GPU | ||
fsdp_config: {} | ||
machine_rank: 0 | ||
main_process_ip: null | ||
main_process_port: null | ||
main_training_function: main | ||
num_machines: 1 | ||
num_processes: 1 | ||
use_cpu: false |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
Dataset: | ||
name: 'BAIR' | ||
dir: '/home/travail/xiyex/BAIR' | ||
phase: 'deploy' | ||
dev_set_size: null | ||
batch_size: 64 | ||
num_workers: 16 | ||
num_channels: 3 | ||
image_size: 64 | ||
num_observed_frames: 2 | ||
num_predict_frames: 10 | ||
test_num_observed_frames: 2 | ||
test_num_predict_frames: 28 | ||
rand_Tp: null | ||
rand_predict: False | ||
half_fps: False | ||
|
||
STDiff: | ||
Diffusion: | ||
unet_config: | ||
sample_size: 64 | ||
DiffNet: | ||
MotionEncoder: | ||
image_size: 64 | ||
|
||
TestCfg: | ||
ckpt_path: "/home/travail/xiyex/STDiff_ckpts/bair_sde_64" | ||
test_results_path: "/home/travail/xiyex/STDiff_ckpts/bair_sde_64/test_ddpm100" | ||
scheduler: | ||
name: 'DDPM' #'DPMMS' or 'DDPM' | ||
sample_steps: 100 | ||
|
||
fps: 1 | ||
metrics: ['PSNR', 'SSIM', 'LPIPS', 'InterLPIPS'] | ||
random_predict: | ||
first_pred_sample_num: 10 | ||
first_pred_parralle_bs: 4 | ||
sample_num: 10 | ||
fix_init_noise: False | ||
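
The test config above lists only a few keys under `STDiff` (`sample_size` and `image_size`), which suggests it is overlaid on the full training config at load time. How the project actually combines the two files is not shown here, so the sketch below is an assumption: a plain recursive dictionary merge that illustrates the idea. OmegaConf, pinned in requirements.txt, provides `OmegaConf.merge` for the same kind of overlay.

```python
import copy


def deep_merge(base, override):
    """Recursively overlay `override` on `base` without mutating either input."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Tiny excerpts of the two configs, written as plain dictionaries.
train_cfg = {"STDiff": {"Diffusion": {"unet_config": {"sample_size": 64, "in_channels": 6}}}}
test_cfg = {
    "STDiff": {"Diffusion": {"unet_config": {"sample_size": 64}}},
    "TestCfg": {"scheduler": {"name": "DDPM", "sample_steps": 100}},
}
cfg = deep_merge(train_cfg, test_cfg)
```

After the merge, `cfg` keeps `in_channels` from the training config and gains the `TestCfg` section from the test config.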
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
Env: | ||
output_dir: '/home/travail/xiyex/STDiff_ckpts/bair_sde_64' | ||
logger: 'tensorboard' | ||
resume_ckpt: null | ||
stdiff_init_ckpt: null | ||
|
||
|
||
Dataset: | ||
name: 'BAIR' | ||
dir: './BAIR' | ||
phase: 'deploy' | ||
dev_set_size: null | ||
batch_size: 6 | ||
num_workers: 32 | ||
num_channels: 3 | ||
image_size: 64 | ||
num_observed_frames: 2 | ||
num_predict_frames: 10 | ||
test_num_observed_frames: 2 | ||
test_num_predict_frames: 10 | ||
rand_Tp: 6 | ||
rand_predict: True | ||
half_fps: False | ||
|
||
STDiff: | ||
Diffusion: | ||
prediction_type: 'epsilon' #'epsilon' or 'sample' | ||
ddpm_num_steps: 1000 | ||
ddpm_num_inference_steps: 300 | ||
ddpm_beta_schedule: 'linear' | ||
|
||
unet_config: | ||
sample_size: 64 | ||
in_channels: 6 | ||
out_channels: 3 | ||
m_channels: 256 | ||
layers_per_block: 2 | ||
#config for resolution 128 | ||
#block_out_channels: [256, 256, 512, 768, 1024] | ||
#down_block_types: ["DownBlock2D","DownBlock2D","DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"] | ||
#up_block_types: ["AttnUpBlock2D", "AttnUpBlock2D","UpBlock2D", "UpBlock2D", "UpBlock2D"] | ||
#attention_head_dim: [null, null, null, 192, 256] | ||
|
||
#config for resolution 64 | ||
block_out_channels: [128, 256, 256, 512, 512] | ||
down_block_types: ["DownBlock2D","AttnDownBlock2D","AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"] | ||
up_block_types: ["AttnUpBlock2D", "AttnUpBlock2D","AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"] | ||
attention_head_dim: [null, 128, 128, 128, 128] | ||
|
||
DiffNet: | ||
autoregressive: True | ||
super_res_training: False | ||
MotionEncoder: | ||
learn_diff_image: True | ||
image_size: 64 | ||
in_channels: 3 | ||
model_channels: 64 | ||
n_downs: 2 | ||
DiffUnet: | ||
n_layers: 2 | ||
nonlinear: 'tanh' | ||
Int: | ||
sde: True | ||
method: 'euler_heun' | ||
sde_options: | ||
noise_type: 'diagonal' | ||
sde_type: "stratonovich" #"Stratonovich" | ||
dt: 0.1 | ||
rtol: 1e-3 | ||
atol: 1e-3 | ||
adaptive: False | ||
ode_options: | ||
step_size: 0.1 | ||
norm: null | ||
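
The `Int` block selects a Stratonovich SDE integrated with the `euler_heun` method at a fixed step `dt: 0.1`. The project presumably delegates this to torchsde (pinned in requirements.txt); the sketch below is only a scalar, pure-Python illustration of one Euler-Heun step for `dy = f(t, y) dt + g(t, y) ∘ dW`. The function name and the toy drift and diffusion in the usage note are hypothetical.

```python
import math
import random


def euler_heun(f, g, y0, t0, t1, dt, rng=None):
    """Integrate a scalar Stratonovich SDE from t0 to t1 with fixed step dt."""
    if rng is None:
        rng = random.Random(0)  # fixed seed so repeated calls are reproducible
    n_steps = round((t1 - t0) / dt)
    t, y = t0, y0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment ~ N(0, dt)
        g_now = g(t, y)
        y_pred = y + g_now * dw             # predictor uses the diffusion term only
        g_next = g(t + dt, y_pred)
        # Drift is a plain Euler step; diffusion averages the two evaluations.
        y = y + f(t, y) * dt + 0.5 * (g_now + g_next) * dw
        t += dt
    return y
```

For example, `euler_heun(lambda t, y: -y, lambda t, y: 0.2 * y, 1.0, 0.0, 1.0, 0.1)` draws one sample path of a decaying process with multiplicative noise. With the diffusion set to zero the scheme reduces to the explicit Euler method.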
|
||
Training: | ||
use_ema: True | ||
ema_inv_gamma: 1.0 | ||
ema_power: 0.75 | ||
ema_max_decay: 0.9999 | ||
|
||
learning_rate: 1e-4 | ||
lr_scheduler: 'cosine_with_restarts' | ||
lr_warmup_steps: 500 | ||
num_cycles: 2 | ||
adam_betas: [0.95, 0.999] | ||
adam_weight_decay: 1e-6 | ||
adam_epsilon: 1e-8 | ||
|
||
epochs: 800 | ||
save_images_epochs: 4 | ||
save_model_epochs: 2 | ||
checkpointing_steps: 10 #number of steps to save a resuming checkpoint | ||
|
||
gradient_accumulation_steps: 1 #4 for 128 resolution, 4 GPU training | ||
mixed_precision: "no" #["no", "fp16", "bf16"], |
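
The `Training` block combines `lr_scheduler: 'cosine_with_restarts'` with `lr_warmup_steps: 500` and `num_cycles: 2`. Assuming this follows the usual cosine-with-hard-restarts shape found in the Hugging Face libraries (an assumption, since the training loop is not shown here), the learning-rate multiplier can be sketched as below. The total step count in the usage note is a made-up value for illustration.

```python
import math


def cosine_with_restarts_factor(step, warmup_steps, total_steps, num_cycles):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if progress >= 1.0:
        return 0.0
    # Each cycle decays from 1 to 0 along a half cosine, then restarts at 1.
    cycle_position = (num_cycles * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))
```

With `learning_rate: 1e-4`, 500 warmup steps, and a hypothetical 10,500 total steps, the rate climbs linearly to 1e-4 at step 500, decays towards zero, jumps back to 1e-4 at step 5,500 for the second cycle, and reaches zero at the end of training.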
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
Dataset: | ||
name: 'CityScapes' | ||
dir: '/home/travail/xiyex/CityScapes' | ||
phase: 'deploy' | ||
dev_set_size: null | ||
batch_size: 64 | ||
num_workers: 16 | ||
num_channels: 3 | ||
image_size: 128 | ||
num_observed_frames: 2 | ||
num_predict_frames: 10 | ||
test_num_observed_frames: 2 | ||
test_num_predict_frames: 28 | ||
rand_Tp: null | ||
rand_predict: False | ||
half_fps: False | ||
|
||
STDiff: | ||
Diffusion: | ||
unet_config: | ||
sample_size: 128 | ||
DiffNet: | ||
MotionEncoder: | ||
image_size: 128 | ||
|
||
TestCfg: | ||
ckpt_path: "/home/travail/xiyex/STDiff_ckpts/city_sde_128" | ||
test_results_path: "/home/travail/xiyex/STDiff_ckpts/city_sde_128/test_ddpm100" | ||
scheduler: | ||
name: 'DDPM' #'DPMMS' or 'DDPM' | ||
sample_steps: 100 | ||
|
||
fps: 1 | ||
metrics: ['PSNR', 'SSIM', 'LPIPS', 'InterLPIPS'] | ||
random_predict: | ||
first_pred_sample_num: 10 | ||
first_pred_parralle_bs: 4 | ||
sample_num: 10 | ||
fix_init_noise: False | ||
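
Of the metrics listed under `metrics`, PSNR is the simplest to state: it is `10 * log10(MAX^2 / MSE)` between a predicted frame and the ground-truth frame. The sketch below works on flat sequences of pixel values and assumes an 8-bit range (`max_value=255`); the pixel range and how the project averages scores over frames and samples are assumptions not confirmed by this config.

```python
import math


def psnr(pred, target, max_value=255.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(pred) != len(target) or len(pred) == 0:
        raise ValueError("inputs must be non-empty and of the same length")
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return math.inf  # identical frames
    return 10.0 * math.log10(max_value ** 2 / mse)
```

Higher values mean the prediction is closer to the ground truth; a constant error of 10 grey levels on 8-bit images gives roughly 28 dB.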
|