Skip to content

Commit

Permalink
Merge branch 'transformer-model-v2'
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Jan 21, 2024
2 parents cc49cf6 + 9737cfd commit e1deea1
Show file tree
Hide file tree
Showing 14 changed files with 1,145 additions and 20 deletions.
74 changes: 73 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,78 @@

An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models.

## Hourglass transformer experimental branch

**This branch is under active development. Models of the new type that are trained with it may stop working due to backward incompatible changes.**

This branch of `k-diffusion` is for testing an experimental model type, `image_transformer_v2`, that uses ideas from [Hourglass Transformer](https://arxiv.org/abs/2110.13711) and [DiT](https://arxiv.org/abs/2212.09748).

### Requirements

To use the new model type you will need to install custom CUDA kernels:

* [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. There is a [shifted window attention](https://arxiv.org/abs/2103.14030) version of the model type which does not require a custom CUDA kernel, but it does not perform as well and is slower to train and inference.

* [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. It will fall back to plain PyTorch if it is not installed.

Also, you should make sure your PyTorch installation is capable of using `torch.compile()` (for instance, if you are using Python 3.11, you should use a PyTorch nightly build instead of 2.0). It will fall back to eager mode if `torch.compile()` is not available, but it will be slower and use more memory in training.

### Usage

#### Demo

To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, install [Hugging Face Datasets](https://huggingface.co/docs/datasets/index):

```sh
pip install datasets
```

and run:

```sh
python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16
```

If you run out of memory, try adding `--checkpointing` or reducing the batch size. If you are using an older GPU (pre-Ampere), omit `--mixed-precision bf16` to train in FP32. It is not recommended to train in FP16.

If you have NATTEN installed and working (preferred), you can train with neighborhood attention instead of shifted window attention by specifying `--config configs/config_oxford_flowers.json`.

#### Config file

In the `"model"` key of the config file:

1. Set the `"type"` key to `"image_transformer_v2"`.

1. The base patch size is set by the `"patch_size"` key, like `"patch_size": [4, 4]`.

1. Model depth for each level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level.

1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension.

1. The self-attention mechanism for each level of the hierarchy is specified by the `"self_attns"` config key, like:

```json
"self_attns": [
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "global", "d_head": 64},
]
```

If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel. The highest level uses global attention with 64 dim heads. So the token count at every level but the highest can be very large.

1. As a fallback if you or your users cannot use NATTEN, you can also train a model with [shifted window attention](https://arxiv.org/abs/2103.14030) at the low levels of the hierarchy. Shifted window attention does not perform as well as neighborhood attention and it is slower to train and inference, but it does not require custom CUDA kernels. Specify it like:

```json
"self_attns": [
{"type": "shifted-window", "d_head": 64, "window_size": 8},
{"type": "shifted-window", "d_head": 64, "window_size": 8},
{"type": "global", "d_head": 64},
]
```

The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type.

## Installation

`k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e <path to repository>`.
Expand Down Expand Up @@ -38,7 +110,7 @@ $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME

## Enhancements/additional features

- k-diffusion has support for training transformer-based diffusion models (like [DiT](https://arxiv.org/abs/2212.09748) but improved).
- k-diffusion supports a highly efficient hierarchical transformer model type.

- k-diffusion supports a soft version of [Min-SNR loss weighting](https://arxiv.org/abs/2303.09556) for improved training at high resolutions with less hyperparameters than the loss weighting used in Karras et al. (2022).

Expand Down
35 changes: 35 additions & 0 deletions config_from_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3

"""Extracts the configuration file from a slim inference checkpoint."""

import argparse
import json
from pathlib import Path
import sys

import k_diffusion as K
import safetensors.torch as safetorch


def main():
p = argparse.ArgumentParser(description=__doc__,
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
p.add_argument("checkpoint", type=Path,
help="the inference checkpoint to extract the configuration from")
p.add_argument("--output", "-o", type=Path,
help="the output configuration file")
args = p.parse_args()

print(f"Loading inference checkpoint {args.checkpoint}...", file=sys.stderr)
metadata = K.utils.get_safetensors_metadata(args.checkpoint)
if "config" not in metadata:
raise ValueError("No configuration found in checkpoint")

output_path = args.output or args.checkpoint.with_suffix(".json")

print(f"Saving configuration to {output_path}...", file=sys.stderr)
output_path.write_text(metadata["config"])


if __name__ == "__main__":
main()
12 changes: 8 additions & 4 deletions configs/config_cifar10_transformer.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
{
"model": {
"type": "image_transformer_v1",
"type": "image_transformer_v2",
"input_channels": 3,
"input_size": [32, 32],
"patch_size": [4, 4],
"width": 512,
"depth": 8,
"patch_size": [2, 2],
"depths": [2, 4],
"widths": [256, 512],
"self_attns": [
{"type": "global"},
{"type": "global"}
],
"loss_config": "karras",
"loss_weighting": "soft-min-snr",
"dropout_rate": 0.05,
Expand Down
6 changes: 3 additions & 3 deletions configs/config_mnist_transformer.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"model": {
"type": "image_transformer_v1",
"type": "image_transformer_v2",
"input_channels": 1,
"input_size": [28, 28],
"patch_size": [4, 4],
"width": 256,
"depth": 8,
"depths": [8],
"widths": [256],
"loss_config": "karras",
"loss_weighting": "soft-min-snr",
"dropout_rate": 0.05,
Expand Down
47 changes: 47 additions & 0 deletions configs/config_oxford_flowers.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"model": {
"type": "image_transformer_v2",
"input_channels": 3,
"input_size": [256, 256],
"patch_size": [4, 4],
"depths": [2, 2, 4],
"widths": [128, 256, 512],
"self_attns": [
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "neighborhood", "d_head": 64, "kernel_size": 7},
{"type": "global", "d_head": 64}
],
"loss_config": "karras",
"loss_weighting": "soft-min-snr",
"dropout_rate": [0.0, 0.0, 0.1],
"mapping_dropout_rate": 0.0,
"augment_prob": 0.0,
"sigma_data": 0.5,
"sigma_min": 1e-2,
"sigma_max": 160,
"sigma_sample_density": {
"type": "cosine-interpolated"
}
},
"dataset": {
"type": "huggingface",
"location": "nelorth/oxford-flowers",
"image_key": "image"
},
"optimizer": {
"type": "adamw",
"lr": 5e-4,
"betas": [0.9, 0.95],
"eps": 1e-8,
"weight_decay": 1e-3
},
"lr_sched": {
"type": "constant",
"warmup": 0.0
},
"ema_sched": {
"type": "inverse",
"power": 0.75,
"max_value": 0.9999
}
}
47 changes: 47 additions & 0 deletions configs/config_oxford_flowers_shifted_window.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"model": {
"type": "image_transformer_v2",
"input_channels": 3,
"input_size": [256, 256],
"patch_size": [4, 4],
"depths": [2, 2, 4],
"widths": [128, 256, 512],
"self_attns": [
{"type": "shifted-window", "d_head": 64, "window_size": 8},
{"type": "shifted-window", "d_head": 64, "window_size": 8},
{"type": "global", "d_head": 64}
],
"loss_config": "karras",
"loss_weighting": "soft-min-snr",
"dropout_rate": [0.0, 0.0, 0.1],
"mapping_dropout_rate": 0.0,
"augment_prob": 0.0,
"sigma_data": 0.5,
"sigma_min": 1e-2,
"sigma_max": 160,
"sigma_sample_density": {
"type": "cosine-interpolated"
}
},
"dataset": {
"type": "huggingface",
"location": "nelorth/oxford-flowers",
"image_key": "image"
},
"optimizer": {
"type": "adamw",
"lr": 5e-4,
"betas": [0.9, 0.95],
"eps": 1e-8,
"weight_decay": 1e-3
},
"lr_sched": {
"type": "constant",
"warmup": 0.0
},
"ema_sched": {
"type": "inverse",
"power": 0.75,
"max_value": 0.9999
}
}
70 changes: 70 additions & 0 deletions k_diffusion/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@ def load_config(path_or_dict):
'weight_decay': 1e-4,
},
}
defaults_image_transformer_v2 = {
'model': {
'mapping_width': 256,
'mapping_depth': 2,
'mapping_d_ff': None,
'mapping_cond_dim': 0,
'mapping_dropout_rate': 0.,
'd_ffs': None,
'self_attns': None,
'dropout_rate': None,
'augment_wrapper': False,
'skip_stages': 0,
'has_variance': False,
},
'optimizer': {
'type': 'adamw',
'lr': 5e-4,
'betas': [0.9, 0.99],
'eps': 1e-8,
'weight_decay': 1e-4,
},
}
defaults = {
'model': {
'sigma_data': 1.,
Expand Down Expand Up @@ -101,6 +123,26 @@ def load_config(path_or_dict):
config = merge(defaults_image_transformer_v1, config)
if not config['model']['d_ff']:
config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05)
elif config['model']['type'] == 'image_transformer_v2':
config = merge(defaults_image_transformer_v2, config)
if not config['model']['mapping_d_ff']:
config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3
if not config['model']['d_ffs']:
d_ffs = []
for width in config['model']['widths']:
d_ffs.append(width * 3)
config['model']['d_ffs'] = d_ffs
if not config['model']['self_attns']:
self_attns = []
default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7}
default_global = {"type": "global", "d_head": 64}
for i in range(len(config['model']['widths'])):
self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global)
config['model']['self_attns'] = self_attns
if config['model']['dropout_rate'] is None:
config['model']['dropout_rate'] = [0.0] * len(config['model']['widths'])
elif isinstance(config['model']['dropout_rate'], float):
config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths'])
return merge(defaults, config)


Expand Down Expand Up @@ -138,6 +180,34 @@ def make_model(config):
dropout=config['dropout_rate'],
sigma_data=config['sigma_data'],
)
elif config['type'] == 'image_transformer_v2':
assert len(config['widths']) == len(config['depths'])
assert len(config['widths']) == len(config['d_ffs'])
assert len(config['widths']) == len(config['self_attns'])
assert len(config['widths']) == len(config['dropout_rate'])
levels = []
for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']):
if self_attn['type'] == 'global':
self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64))
elif self_attn['type'] == 'neighborhood':
self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7))
elif self_attn['type'] == 'shifted-window':
self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size'])
elif self_attn['type'] == 'none':
self_attn = models.image_transformer_v2.NoAttentionSpec()
else:
raise ValueError(f'unsupported self attention type {self_attn["type"]}')
levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout))
mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate'])
model = models.ImageTransformerDenoiserModelV2(
levels=levels,
mapping=mapping,
in_channels=config['input_channels'],
out_channels=config['input_channels'],
patch_size=config['patch_size'],
num_classes=num_classes + 1 if num_classes else 0,
mapping_cond_dim=config['mapping_cond_dim'],
)
else:
raise ValueError(f'unsupported model type {config["type"]}')
return model
Expand Down
2 changes: 2 additions & 0 deletions k_diffusion/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from . import flops
from .flags import checkpointing, get_checkpointing
from .image_v1 import ImageDenoiserModelV1
from .image_transformer_v1 import ImageTransformerDenoiserModelV1
from .image_transformer_v2 import ImageTransformerDenoiserModelV2
Loading

0 comments on commit e1deea1

Please sign in to comment.