diff --git a/README.md b/README.md index b74c5dc7..0007c4db 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,78 @@ An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models. +## Hourglass transformer experimental branch + +**This branch is under active development. Models of the new type that are trained with it may stop working due to backward incompatible changes.** + +This branch of `k-diffusion` is for testing an experimental model type, `image_transformer_v2`, that uses ideas from [Hourglass Transformer](https://arxiv.org/abs/2110.13711) and [DiT](https://arxiv.org/abs/2212.09748). + +### Requirements + +To use the new model type you will need to install custom CUDA kernels: + +* [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. There is a [shifted window attention](https://arxiv.org/abs/2103.14030) version of the model type which does not require a custom CUDA kernel, but it does not perform as well and is slower to train and inference. + +* [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. It will fall back to plain PyTorch if it is not installed. + +Also, you should make sure your PyTorch installation is capable of using `torch.compile()` (for instance, if you are using Python 3.11, you should use a PyTorch nightly build instead of 2.0). It will fall back to eager mode if `torch.compile()` is not available, but it will be slower and use more memory in training. + +### Usage + +#### Demo + +To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, install [Hugging Face Datasets](https://huggingface.co/docs/datasets/index): + +```sh +pip install datasets +``` + +and run: + +```sh +python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16 +``` + +If you run out of memory, try adding `--checkpointing` or reducing the batch size. If you are using an older GPU (pre-Ampere), omit `--mixed-precision bf16` to train in FP32. It is not recommended to train in FP16. + +If you have NATTEN installed and working (preferred), you can train with neighborhood attention instead of shifted window attention by specifying `--config configs/config_oxford_flowers.json`. + +#### Config file + +In the `"model"` key of the config file: + +1. Set the `"type"` key to `"image_transformer_v2"`. + +1. The base patch size is set by the `"patch_size"` key, like `"patch_size": [4, 4]`. + +1. Model depth for each level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level. + +1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension. + +1. The self-attention mechanism for each level of the hierarchy is specified by the `"self_attns"` config key, like: + + ```json + "self_attns": [ + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "global", "d_head": 64}, + ] + ``` + + If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel. The highest level uses global attention with 64 dim heads. So the token count at every level but the highest can be very large. + +1. As a fallback if you or your users cannot use NATTEN, you can also train a model with [shifted window attention](https://arxiv.org/abs/2103.14030) at the low levels of the hierarchy. Shifted window attention does not perform as well as neighborhood attention and it is slower to train and inference, but it does not require custom CUDA kernels. Specify it like: + + ```json + "self_attns": [ + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "global", "d_head": 64}, + ] + ``` + + The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type. + ## Installation `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e `. @@ -38,7 +110,7 @@ $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME ## Enhancements/additional features -- k-diffusion has support for training transformer-based diffusion models (like [DiT](https://arxiv.org/abs/2212.09748) but improved). +- k-diffusion supports a highly efficient hierarchical transformer model type. - k-diffusion supports a soft version of [Min-SNR loss weighting](https://arxiv.org/abs/2303.09556) for improved training at high resolutions with less hyperparameters than the loss weighting used in Karras et al. (2022). diff --git a/config_from_inference.py b/config_from_inference.py new file mode 100755 index 00000000..f83b49f3 --- /dev/null +++ b/config_from_inference.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +"""Extracts the configuration file from a slim inference checkpoint.""" + +import argparse +import json +from pathlib import Path +import sys + +import k_diffusion as K +import safetensors.torch as safetorch + + +def main(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + p.add_argument("checkpoint", type=Path, + help="the inference checkpoint to extract the configuration from") + p.add_argument("--output", "-o", type=Path, + help="the output configuration file") + args = p.parse_args() + + print(f"Loading inference checkpoint {args.checkpoint}...", file=sys.stderr) + metadata = K.utils.get_safetensors_metadata(args.checkpoint) + if "config" not in metadata: + raise ValueError("No configuration found in checkpoint") + + output_path = args.output or args.checkpoint.with_suffix(".json") + + print(f"Saving configuration to {output_path}...", file=sys.stderr) + output_path.write_text(metadata["config"]) + + +if __name__ == "__main__": + main() diff --git a/configs/config_cifar10_transformer.json b/configs/config_cifar10_transformer.json index 1a9ef096..958cd4cf 100644 --- a/configs/config_cifar10_transformer.json +++ b/configs/config_cifar10_transformer.json @@ -1,11 +1,15 @@ { "model": { - "type": "image_transformer_v1", + "type": "image_transformer_v2", "input_channels": 3, "input_size": [32, 32], - "patch_size": [4, 4], - "width": 512, - "depth": 8, + "patch_size": [2, 2], + "depths": [2, 4], + "widths": [256, 512], + "self_attns": [ + {"type": "global"}, + {"type": "global"} + ], "loss_config": "karras", "loss_weighting": "soft-min-snr", "dropout_rate": 0.05, diff --git a/configs/config_mnist_transformer.json b/configs/config_mnist_transformer.json index 0ea11891..564441f3 100644 --- a/configs/config_mnist_transformer.json +++ b/configs/config_mnist_transformer.json @@ -1,11 +1,11 @@ { "model": { - "type": "image_transformer_v1", + "type": "image_transformer_v2", "input_channels": 1, "input_size": [28, 28], "patch_size": [4, 4], - "width": 256, - "depth": 8, + "depths": [8], + "widths": [256], "loss_config": "karras", "loss_weighting": "soft-min-snr", "dropout_rate": 0.05, diff --git a/configs/config_oxford_flowers.json b/configs/config_oxford_flowers.json new file mode 100644 index 00000000..fe3f7594 --- /dev/null +++ b/configs/config_oxford_flowers.json @@ -0,0 +1,47 @@ +{ + "model": { + "type": "image_transformer_v2", + "input_channels": 3, + "input_size": [256, 256], + "patch_size": [4, 4], + "depths": [2, 2, 4], + "widths": [128, 256, 512], + "self_attns": [ + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "global", "d_head": 64} + ], + "loss_config": "karras", + "loss_weighting": "soft-min-snr", + "dropout_rate": [0.0, 0.0, 0.1], + "mapping_dropout_rate": 0.0, + "augment_prob": 0.0, + "sigma_data": 0.5, + "sigma_min": 1e-2, + "sigma_max": 160, + "sigma_sample_density": { + "type": "cosine-interpolated" + } + }, + "dataset": { + "type": "huggingface", + "location": "nelorth/oxford-flowers", + "image_key": "image" + }, + "optimizer": { + "type": "adamw", + "lr": 5e-4, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 1e-3 + }, + "lr_sched": { + "type": "constant", + "warmup": 0.0 + }, + "ema_sched": { + "type": "inverse", + "power": 0.75, + "max_value": 0.9999 + } +} diff --git a/configs/config_oxford_flowers_shifted_window.json b/configs/config_oxford_flowers_shifted_window.json new file mode 100644 index 00000000..1bb2932c --- /dev/null +++ b/configs/config_oxford_flowers_shifted_window.json @@ -0,0 +1,47 @@ +{ + "model": { + "type": "image_transformer_v2", + "input_channels": 3, + "input_size": [256, 256], + "patch_size": [4, 4], + "depths": [2, 2, 4], + "widths": [128, 256, 512], + "self_attns": [ + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "global", "d_head": 64} + ], + "loss_config": "karras", + "loss_weighting": "soft-min-snr", + "dropout_rate": [0.0, 0.0, 0.1], + "mapping_dropout_rate": 0.0, + "augment_prob": 0.0, + "sigma_data": 0.5, + "sigma_min": 1e-2, + "sigma_max": 160, + "sigma_sample_density": { + "type": "cosine-interpolated" + } + }, + "dataset": { + "type": "huggingface", + "location": "nelorth/oxford-flowers", + "image_key": "image" + }, + "optimizer": { + "type": "adamw", + "lr": 5e-4, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 1e-3 + }, + "lr_sched": { + "type": "constant", + "warmup": 0.0 + }, + "ema_sched": { + "type": "inverse", + "power": 0.75, + "max_value": 0.9999 + } +} diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 49a37189..8ab98b6e 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -55,6 +55,28 @@ def load_config(path_or_dict): 'weight_decay': 1e-4, }, } + defaults_image_transformer_v2 = { + 'model': { + 'mapping_width': 256, + 'mapping_depth': 2, + 'mapping_d_ff': None, + 'mapping_cond_dim': 0, + 'mapping_dropout_rate': 0., + 'd_ffs': None, + 'self_attns': None, + 'dropout_rate': None, + 'augment_wrapper': False, + 'skip_stages': 0, + 'has_variance': False, + }, + 'optimizer': { + 'type': 'adamw', + 'lr': 5e-4, + 'betas': [0.9, 0.99], + 'eps': 1e-8, + 'weight_decay': 1e-4, + }, + } defaults = { 'model': { 'sigma_data': 1., @@ -101,6 +123,26 @@ def load_config(path_or_dict): config = merge(defaults_image_transformer_v1, config) if not config['model']['d_ff']: config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05) + elif config['model']['type'] == 'image_transformer_v2': + config = merge(defaults_image_transformer_v2, config) + if not config['model']['mapping_d_ff']: + config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3 + if not config['model']['d_ffs']: + d_ffs = [] + for width in config['model']['widths']: + d_ffs.append(width * 3) + config['model']['d_ffs'] = d_ffs + if not config['model']['self_attns']: + self_attns = [] + default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7} + default_global = {"type": "global", "d_head": 64} + for i in range(len(config['model']['widths'])): + self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global) + config['model']['self_attns'] = self_attns + if config['model']['dropout_rate'] is None: + config['model']['dropout_rate'] = [0.0] * len(config['model']['widths']) + elif isinstance(config['model']['dropout_rate'], float): + config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths']) return merge(defaults, config) @@ -138,6 +180,34 @@ def make_model(config): dropout=config['dropout_rate'], sigma_data=config['sigma_data'], ) + elif config['type'] == 'image_transformer_v2': + assert len(config['widths']) == len(config['depths']) + assert len(config['widths']) == len(config['d_ffs']) + assert len(config['widths']) == len(config['self_attns']) + assert len(config['widths']) == len(config['dropout_rate']) + levels = [] + for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']): + if self_attn['type'] == 'global': + self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) + elif self_attn['type'] == 'neighborhood': + self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) + elif self_attn['type'] == 'shifted-window': + self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size']) + elif self_attn['type'] == 'none': + self_attn = models.image_transformer_v2.NoAttentionSpec() + else: + raise ValueError(f'unsupported self attention type {self_attn["type"]}') + levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout)) + mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate']) + model = models.ImageTransformerDenoiserModelV2( + levels=levels, + mapping=mapping, + in_channels=config['input_channels'], + out_channels=config['input_channels'], + patch_size=config['patch_size'], + num_classes=num_classes + 1 if num_classes else 0, + mapping_cond_dim=config['mapping_cond_dim'], + ) else: raise ValueError(f'unsupported model type {config["type"]}') return model diff --git a/k_diffusion/models/__init__.py b/k_diffusion/models/__init__.py index 74986c90..2ee22441 100644 --- a/k_diffusion/models/__init__.py +++ b/k_diffusion/models/__init__.py @@ -1,3 +1,5 @@ +from . import flops from .flags import checkpointing, get_checkpointing from .image_v1 import ImageDenoiserModelV1 from .image_transformer_v1 import ImageTransformerDenoiserModelV1 +from .image_transformer_v2 import ImageTransformerDenoiserModelV2 diff --git a/k_diffusion/models/flops.py b/k_diffusion/models/flops.py new file mode 100644 index 00000000..e98e2a81 --- /dev/null +++ b/k_diffusion/models/flops.py @@ -0,0 +1,54 @@ +from contextlib import contextmanager +import math +import threading + + +state = threading.local() +state.flop_counter = None + + +@contextmanager +def flop_counter(enable=True): + try: + old_flop_counter = state.flop_counter + state.flop_counter = FlopCounter() if enable else None + yield state.flop_counter + finally: + state.flop_counter = old_flop_counter + + +class FlopCounter: + def __init__(self): + self.ops = [] + + def op(self, op, *args, **kwargs): + self.ops.append((op, args, kwargs)) + + @property + def flops(self): + flops = 0 + for op, args, kwargs in self.ops: + flops += op(*args, **kwargs) + return flops + + +def op(op, *args, **kwargs): + if getattr(state, "flop_counter", None): + state.flop_counter.op(op, *args, **kwargs) + + +def op_linear(x, weight): + return math.prod(x) * weight[0] + + +def op_attention(q, k, v): + *b, s_q, d_q = q + *b, s_k, d_k = k + *b, s_v, d_v = v + return math.prod(b) * s_q * s_k * (d_q + d_v) + + +def op_natten(q, k, v, kernel_size): + *q_rest, d_q = q + *_, d_v = v + return math.prod(q_rest) * (d_q + d_v) * kernel_size**2 diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py new file mode 100644 index 00000000..f7ac2094 --- /dev/null +++ b/k_diffusion/models/image_transformer_v2.py @@ -0,0 +1,743 @@ +"""k-diffusion transformer diffusion models, version 2.""" + +from dataclasses import dataclass +from functools import lru_cache, reduce +import math +from typing import Union + +from einops import rearrange +import torch +from torch import nn +import torch._dynamo +from torch.nn import functional as F + +from . import flags, flops +from .. import layers +from .axial_rope import make_axial_pos + + +try: + import natten +except ImportError: + natten = None + +try: + import flash_attn +except ImportError: + flash_attn = None + + +if flags.get_use_compile(): + torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) + torch._dynamo.config.suppress_errors = True + + +# Helpers + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + + +def checkpoint(function, *args, **kwargs): + if flags.get_checkpointing(): + kwargs.setdefault("use_reentrant", True) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + else: + return function(*args, **kwargs) + + +def downscale_pos(pos): + pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", nh=2, nw=2) + return torch.mean(pos, dim=-2) + + +# Param tags + +def tag_param(param, tag): + if not hasattr(param, "_tags"): + param._tags = set([tag]) + else: + param._tags.add(tag) + return param + + +def tag_module(module, tag): + for param in module.parameters(): + tag_param(param, tag) + return module + + +def apply_wd(module): + for name, param in module.named_parameters(): + if name.endswith("weight"): + tag_param(param, "wd") + return module + + +def filter_params(function, module): + for param in module.parameters(): + tags = getattr(param, "_tags", set()) + if function(tags): + yield param + + +# Kernels + +@flags.compile_wrap +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@flags.compile_wrap +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + + +@flags.compile_wrap +def scale_for_cosine_sim(q, k, scale, eps): + dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32)) + sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) + sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True) + sqrt_scale = torch.sqrt(scale.to(dtype)) + scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps) + scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps) + return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype) + + +@flags.compile_wrap +def scale_for_cosine_sim_qkv(qkv, scale, eps): + q, k, v = qkv.unbind(2) + q, k = scale_for_cosine_sim(q, k, scale[:, None], eps) + return torch.stack((q, k, v), dim=2) + + +# Layers + +class Linear(nn.Linear): + def forward(self, x): + flops.op(flops.op_linear, x.shape, self.weight.shape) + return super().forward(x) + + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + flops.op(flops.op_linear, x.shape, self.weight.shape) + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, eps=1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = apply_wd(zero_init(Linear(cond_features, features, bias=False))) + tag_module(self.linear, "mapping") + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, None, :] + 1, self.eps) + + +# Rotary position embeddings + +@flags.compile_wrap +def apply_rotary_emb(x, theta, conj=False): + out_dtype = x.dtype + dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) + d = theta.shape[-1] + assert d * 2 <= x.shape[-1] + x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] + x1, x2, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) + cos, sin = torch.cos(theta), torch.sin(theta) + sin = -sin if conj else sin + y1 = x1 * cos - x2 * sin + y2 = x2 * cos + x1 * sin + y1, y2 = y1.to(out_dtype), y2.to(out_dtype) + return torch.cat((y1, y2, x3), dim=-1) + + +@flags.compile_wrap +def _apply_rotary_emb_inplace(x, theta, conj): + dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) + d = theta.shape[-1] + assert d * 2 <= x.shape[-1] + x1, x2 = x[..., :d], x[..., d : d * 2] + x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) + cos, sin = torch.cos(theta), torch.sin(theta) + sin = -sin if conj else sin + y1 = x1_ * cos - x2_ * sin + y2 = x2_ * cos + x1_ * sin + x1.copy_(y1) + x2.copy_(y2) + + +class ApplyRotaryEmbeddingInplace(torch.autograd.Function): + @staticmethod + def forward(x, theta, conj): + _apply_rotary_emb_inplace(x, theta, conj=conj) + return x + + @staticmethod + def setup_context(ctx, inputs, output): + _, theta, conj = inputs + ctx.save_for_backward(theta) + ctx.conj = conj + + @staticmethod + def backward(ctx, grad_output): + theta, = ctx.saved_tensors + _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj) + return grad_output, None, None + + +def apply_rotary_emb_(x, theta): + return ApplyRotaryEmbeddingInplace.apply(x, theta, False) + + +class AxialRoPE(nn.Module): + def __init__(self, dim, n_heads): + super().__init__() + log_min = math.log(math.pi) + log_max = math.log(10.0 * math.pi) + freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp() + self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous()) + + def extra_repr(self): + return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" + + def forward(self, pos): + theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) + theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) + return torch.cat((theta_h, theta_w), dim=-1) + + +# Shifted window attention + +def window(window_size, x): + *b, h, w, c = x.shape + x = torch.reshape( + x, + (*b, h // window_size, window_size, w // window_size, window_size, c), + ) + x = torch.permute( + x, + (*range(len(b)), -5, -3, -4, -2, -1), + ) + return x + + +def unwindow(x): + *b, h, w, wh, ww, c = x.shape + x = torch.permute(x, (*range(len(b)), -5, -3, -4, -2, -1)) + x = torch.reshape(x, (*b, h * wh, w * ww, c)) + return x + + +def shifted_window(window_size, window_shift, x): + x = torch.roll(x, shifts=(window_shift, window_shift), dims=(-2, -3)) + windows = window(window_size, x) + return windows + + +def shifted_unwindow(window_shift, x): + x = unwindow(x) + x = torch.roll(x, shifts=(-window_shift, -window_shift), dims=(-2, -3)) + return x + + +@lru_cache +def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None): + ph_coords = torch.arange(n_h_w, device=device) + pw_coords = torch.arange(n_w_w, device=device) + h_coords = torch.arange(w_h, device=device) + w_coords = torch.arange(w_w, device=device) + patch_h, patch_w, q_h, q_w, k_h, k_w = torch.meshgrid( + ph_coords, + pw_coords, + h_coords, + w_coords, + h_coords, + w_coords, + indexing="ij", + ) + is_top_patch = patch_h == 0 + is_left_patch = patch_w == 0 + q_above_shift = q_h < shift + k_above_shift = k_h < shift + q_left_of_shift = q_w < shift + k_left_of_shift = k_w < shift + m_corner = ( + is_left_patch + & is_top_patch + & (q_left_of_shift == k_left_of_shift) + & (q_above_shift == k_above_shift) + ) + m_left = is_left_patch & ~is_top_patch & (q_left_of_shift == k_left_of_shift) + m_top = ~is_left_patch & is_top_patch & (q_above_shift == k_above_shift) + m_rest = ~is_left_patch & ~is_top_patch + m = m_corner | m_left | m_top | m_rest + return m + + +def apply_window_attention(window_size, window_shift, q, k, v, scale=None): + # prep windows and masks + q_windows = shifted_window(window_size, window_shift, q) + k_windows = shifted_window(window_size, window_shift, k) + v_windows = shifted_window(window_size, window_shift, v) + b, heads, h, w, wh, ww, d_head = q_windows.shape + mask = make_shifted_window_masks(h, w, wh, ww, window_shift, device=q.device) + q_seqs = torch.reshape(q_windows, (b, heads, h, w, wh * ww, d_head)) + k_seqs = torch.reshape(k_windows, (b, heads, h, w, wh * ww, d_head)) + v_seqs = torch.reshape(v_windows, (b, heads, h, w, wh * ww, d_head)) + mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) + + # do the attention here + flops.op(flops.op_attention, q_seqs.shape, k_seqs.shape, v_seqs.shape) + qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale) + + # unwindow + qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head)) + return shifted_unwindow(window_shift, qkv) + + +# Transformer layers + + +def use_flash_2(x): + if not flags.get_use_flash_attention_2(): + return False + if flash_attn is None: + return False + if x.device.type != "cuda": + return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False + return True + + +class SelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}," + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) + theta = self.pos_emb(pos) + if use_flash_2(qkv): + qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) + qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) + theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3) + qkv = apply_rotary_emb_(qkv, theta) + flops_shape = qkv.shape[-5], qkv.shape[-2], qkv.shape[-4], qkv.shape[-1] + flops.op(flops.op_attention, flops_shape, flops_shape, flops_shape) + x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) + x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) + else: + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6) + theta = theta.movedim(-2, -3) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) + flops.op(flops.op_attention, q.shape, k.shape, v.shape) + x = F.scaled_dot_product_attention(q, k, v, scale=1.0) + x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + +class NeighborhoodSelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.kernel_size = kernel_size + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}, kernel_size={self.kernel_size}" + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + theta = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) + if natten is None: + raise ModuleNotFoundError("natten is required for neighborhood attention") + flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) + qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) + a = torch.softmax(qk, dim=-1) + x = natten.functional.natten2dav(a, v, self.kernel_size, 1) + x = rearrange(x, "n nh h w e -> n h w (nh e)") + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + +class ShiftedWindowSelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.window_size = window_size + self.window_shift = window_shift + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}" + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + theta = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) + x = apply_window_attention(self.window_size, self.window_shift, q, k, v, scale=1.0) + x = rearrange(x, "n nh h w e -> n h w (nh e)") + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + +class FeedForwardBlock(nn.Module): + def __init__(self, d_model, d_ff, cond_features, dropout=0.0): + super().__init__() + self.norm = AdaRMSNorm(d_model, cond_features) + self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) + self.dropout = nn.Dropout(dropout) + self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) + + def forward(self, x, cond): + skip = x + x = self.norm(x, cond) + x = self.up_proj(x) + x = self.dropout(x) + x = self.down_proj(x) + return x + skip + + +class GlobalTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0): + super().__init__() + self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + +class NeighborhoodTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, kernel_size, dropout=0.0): + super().__init__() + self.self_attn = NeighborhoodSelfAttentionBlock(d_model, d_head, cond_features, kernel_size, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + +class ShiftedWindowTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, window_size, index, dropout=0.0): + super().__init__() + window_shift = window_size // 2 if index % 2 == 1 else 0 + self.self_attn = ShiftedWindowSelfAttentionBlock(d_model, d_head, cond_features, window_size, window_shift, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + +class NoAttentionTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, cond_features, dropout=0.0): + super().__init__() + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.ff, x, cond) + return x + + +class Level(nn.ModuleList): + def forward(self, x, *args, **kwargs): + for layer in self: + x = layer(x, *args, **kwargs) + return x + + +# Mapping network + +class MappingFeedForwardBlock(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.0): + super().__init__() + self.norm = RMSNorm(d_model) + self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) + self.dropout = nn.Dropout(dropout) + self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) + + def forward(self, x): + skip = x + x = self.norm(x) + x = self.up_proj(x) + x = self.dropout(x) + x = self.down_proj(x) + return x + skip + + +class MappingNetwork(nn.Module): + def __init__(self, n_layers, d_model, d_ff, dropout=0.0): + super().__init__() + self.in_norm = RMSNorm(d_model) + self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]) + self.out_norm = RMSNorm(d_model) + + def forward(self, x): + x = self.in_norm(x) + for block in self.blocks: + x = block(x) + x = self.out_norm(x) + return x + + +# Token merging and splitting + +class TokenMerge(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = apply_wd(Linear(in_features * self.h * self.w, out_features, bias=False)) + + def forward(self, x): + x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w) + return self.proj(x) + + +class TokenSplitWithoutSkip(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) + + def forward(self, x): + x = self.proj(x) + return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) + + +class TokenSplit(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) + self.fac = nn.Parameter(torch.ones(1) * 0.5) + + def forward(self, x, skip): + x = self.proj(x) + x = rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) + return torch.lerp(skip, x, self.fac.to(x.dtype)) + + +# Configuration + +@dataclass +class GlobalAttentionSpec: + d_head: int + + +@dataclass +class NeighborhoodAttentionSpec: + d_head: int + kernel_size: int + + +@dataclass +class ShiftedWindowAttentionSpec: + d_head: int + window_size: int + + +@dataclass +class NoAttentionSpec: + pass + + +@dataclass +class LevelSpec: + depth: int + width: int + d_ff: int + self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec] + dropout: float + + +@dataclass +class MappingSpec: + depth: int + width: int + d_ff: int + dropout: float + + +# Model class + +class ImageTransformerDenoiserModelV2(nn.Module): + def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0): + super().__init__() + self.num_classes = num_classes + + self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size) + + self.time_emb = layers.FourierFeatures(1, mapping.width) + self.time_in_proj = Linear(mapping.width, mapping.width, bias=False) + self.aug_emb = layers.FourierFeatures(9, mapping.width) + self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False) + self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None + self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None + self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout), "mapping") + + self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() + for i, spec in enumerate(levels): + if isinstance(spec.self_attn, GlobalAttentionSpec): + layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout) + elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): + layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout) + elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec): + layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout) + elif isinstance(spec.self_attn, NoAttentionSpec): + layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout) + else: + raise ValueError(f"unsupported self attention spec {spec.self_attn}") + + if i < len(levels) - 1: + self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)])) + self.up_levels.append(Level([layer_factory(i + spec.depth) for i in range(spec.depth)])) + else: + self.mid_level = Level([layer_factory(i) for i in range(spec.depth)]) + + self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) + self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) + + self.out_norm = RMSNorm(levels[0].width) + self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channels, patch_size) + nn.init.zeros_(self.patch_out.proj.weight) + + def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3): + wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self) + no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self) + mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self) + mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self) + groups = [ + {"params": list(wd), "lr": base_lr}, + {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0}, + {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale}, + {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0} + ] + return groups + + def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): + # Patching + x = x.movedim(-3, -1) + x = self.patch_in(x) + # TODO: pixel aspect ratio for nonsquare patches + pos = make_axial_pos(x.shape[-3], x.shape[-2], device=x.device).view(x.shape[-3], x.shape[-2], 2) + + # Mapping network + if class_cond is None and self.class_emb is not None: + raise ValueError("class_cond must be specified if num_classes > 0") + if mapping_cond is None and self.mapping_cond_in_proj is not None: + raise ValueError("mapping_cond must be specified if mapping_cond_dim > 0") + + c_noise = torch.log(sigma) / 4 + time_emb = self.time_in_proj(self.time_emb(c_noise[..., None])) + aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond + aug_emb = self.aug_in_proj(self.aug_emb(aug_cond)) + class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0 + mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0 + cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb) + + # Hourglass transformer + skips, poses = [], [] + for down_level, merge in zip(self.down_levels, self.merges): + x = down_level(x, pos, cond) + skips.append(x) + poses.append(pos) + x = merge(x) + pos = downscale_pos(pos) + + x = self.mid_level(x, pos, cond) + + for up_level, split, skip, pos in reversed(list(zip(self.up_levels, self.splits, skips, poses))): + x = split(x, skip) + x = up_level(x, pos, cond) + + # Unpatching + x = self.out_norm(x) + x = self.patch_out(x) + x = x.movedim(-1, -3) + + return x diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 311d9618..946f2da5 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -4,6 +4,7 @@ from pathlib import Path import shutil import threading +import time import urllib import warnings diff --git a/requirements.txt b/requirements.txt index af3c20f7..c62792ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ Pillow safetensors scikit-image scipy -torch>=2.0 +torch>=2.1 torchdiffeq torchsde torchvision diff --git a/setup.cfg b/setup.cfg index 38b5717e..1e73b1f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = k-diffusion -version = 0.1.1.post1 +version = 0.2.0.dev0 author = Katherine Crowson author_email = crowsonkb@gmail.com url = https://github.com/crowsonkb/k-diffusion @@ -23,7 +23,7 @@ install_requires = safetensors scikit-image scipy - torch >= 2.0 + torch >= 2.1 torchdiffeq torchsde torchvision diff --git a/train.py b/train.py index 27182a7e..2c771241 100755 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import math import json from pathlib import Path +import time import accelerate import safetensors.torch as safetorch @@ -17,7 +18,7 @@ from torch import distributed as dist from torch import multiprocessing as mp from torch import optim -from torch.utils import data +from torch.utils import data, flop_counter from torchvision import datasets, transforms, utils from tqdm.auto import tqdm @@ -51,12 +52,14 @@ def main(): p.add_argument('--end-step', type=int, default=None, help='the step to end training at') p.add_argument('--evaluate-every', type=int, default=10000, - help='save a demo grid every this many steps') + help='evaluate every this many steps') + p.add_argument('--evaluate-n', type=int, default=2000, + help='the number of samples to draw to evaluate') + p.add_argument('--evaluate-only', action='store_true', + help='evaluate instead of training') p.add_argument('--evaluate-with', type=str, default='inception', choices=['inception', 'clip', 'dinov2'], help='the feature extractor to use for evaluation') - p.add_argument('--evaluate-n', type=int, default=2000, - help='the number of samples to draw to evaluate') p.add_argument('--gns', action='store_true', help='measure the gradient noise scale (DDP only, disables stratified sampling)') p.add_argument('--grad-accum-steps', type=int, default=1, @@ -117,11 +120,16 @@ def main(): device = accelerator.device unwrap = accelerator.unwrap_model print(f'Process {accelerator.process_index} using device: {device}', flush=True) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + print(f'World size: {accelerator.num_processes}', flush=True) + print(f'Batch size: {args.batch_size * accelerator.num_processes}', flush=True) if args.seed is not None: seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) torch.manual_seed(seeds[accelerator.process_index]) demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) + elapsed = 0.0 inner_model = K.config.make_model(config) inner_model_ema = deepcopy(inner_model) @@ -150,6 +158,13 @@ def main(): betas=tuple(opt_config['betas']), eps=opt_config['eps'], weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'adam8bit': + import bitsandbytes as bnb + opt = bnb.optim.Adam8bit(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) elif opt_config['type'] == 'sgd': opt = optim.SGD(groups, lr=lr, @@ -221,9 +236,20 @@ def main(): class_key = dataset_config.get('class_key', 1) train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, - num_workers=args.num_workers, persistent_workers=True) + num_workers=args.num_workers, persistent_workers=True, pin_memory=True) inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) + + with torch.no_grad(), K.models.flops.flop_counter() as fc: + x = torch.zeros([1, model_config['input_channels'], size[0], size[1]], device=device) + sigma = torch.ones([1], device=device) + extra_args = {} + if getattr(unwrap(inner_model), "num_classes", 0): + extra_args['class_cond'] = torch.zeros([1], dtype=torch.long, device=device) + inner_model(x, sigma, **extra_args) + if accelerator.is_main_process: + print(f"Forward pass GFLOPs: {fc.flops / 1_000_000_000:,.3f}", flush=True) + if use_wandb: wandb.watch(inner_model) if accelerator.num_processes == 1: @@ -262,6 +288,7 @@ def main(): if args.gns and ckpt.get('gns_stats', None) is not None: gns_stats.load_state_dict(ckpt['gns_stats']) demo_gen.set_state(ckpt['demo_gen']) + elapsed = ckpt.get('elapsed', 0.0) del ckpt else: @@ -283,6 +310,7 @@ def main(): del ckpt evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 + metrics_log = None if evaluate_enabled: if args.evaluate_with == 'inception': extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) @@ -296,8 +324,8 @@ def main(): if accelerator.is_main_process: print('Computing features for reals...') reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) - if accelerator.is_main_process: - metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'loss', 'fid', 'kid']) + if accelerator.is_main_process and not args.evaluate_only: + metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'time', 'loss', 'fid', 'kid']) del train_iter cfg_scale = 1. @@ -361,8 +389,8 @@ def sample_fn(n): fid = K.evaluation.fid(fakes_features, reals_features) kid = K.evaluation.kid(fakes_features, reals_features) print(f'FID: {fid.item():g}, KID: {kid.item():g}') - if accelerator.is_main_process: - metrics_log.write(step, ema_stats['loss'], fid.item(), kid.item()) + if accelerator.is_main_process and metrics_log is not None: + metrics_log.write(step, elapsed, ema_stats['loss'], fid.item(), kid.item()) if use_wandb: wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) @@ -385,6 +413,7 @@ def save(): 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 'ema_stats': ema_stats, 'demo_gen': demo_gen.get_state(), + 'elapsed': elapsed, } accelerator.save(obj, filename) if accelerator.is_main_process: @@ -393,11 +422,25 @@ def save(): if args.wandb_save_model and use_wandb: wandb.save(filename) + if args.evaluate_only: + if not evaluate_enabled: + raise ValueError('--evaluate-only requested but evaluation is disabled') + evaluate() + return + losses_since_last_print = [] try: while True: for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process): + if device.type == 'cuda': + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_timer.record() + else: + start_timer = time.time() + with accelerator.accumulate(model): reals, _, aug_cond = batch[image_key] class_cond, extra_args = None, {} @@ -429,6 +472,13 @@ def save(): K.utils.ema_update(model, model_ema, ema_decay) ema_sched.step() + if device.type == 'cuda': + end_timer.record() + torch.cuda.synchronize() + elapsed += start_timer.elapsed_time(end_timer) / 1000 + else: + elapsed += time.time() - start_timer + if step % 25 == 0: loss_disp = sum(losses_since_last_print) / len(losses_since_last_print) losses_since_last_print.clear()