From 1d874233b96ea23cf22fb2b7c57004b2df19f2e7 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 21:08:47 +0000 Subject: [PATCH 01/43] Transformer model v2 initial commit --- k_diffusion/config.py | 42 ++ k_diffusion/models/__init__.py | 1 + k_diffusion/models/image_transformer_v2.py | 528 +++++++++++++++++++++ 3 files changed, 571 insertions(+) create mode 100644 k_diffusion/models/image_transformer_v2.py diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 49a3718..42b92a0 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -55,6 +55,21 @@ def load_config(path_or_dict): 'weight_decay': 1e-4, }, } + defaults_image_transformer_v2 = { + 'model': { + 'd_ffs': None, + 'augment_wrapper': False, + 'skip_stages': 0, + 'has_variance': False, + }, + 'optimizer': { + 'type': 'adamw', + 'lr': 5e-4, + 'betas': [0.9, 0.99], + 'eps': 1e-8, + 'weight_decay': 1e-4, + }, + } defaults = { 'model': { 'sigma_data': 1., @@ -101,6 +116,13 @@ def load_config(path_or_dict): config = merge(defaults_image_transformer_v1, config) if not config['model']['d_ff']: config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05) + elif config['model']['type'] == 'image_transformer_v2': + config = merge(defaults_image_transformer_v2, config) + if not config['model']['d_ffs']: + d_ffs = [] + for width in config['model']['widths']: + d_ffs.append(round_to_power_of_two(width * 8 / 3, tol=0.05)) + config['model']['d_ffs'] = d_ffs return merge(defaults, config) @@ -138,6 +160,26 @@ def make_model(config): dropout=config['dropout_rate'], sigma_data=config['sigma_data'], ) + elif config['type'] == 'image_transformer_v2': + assert len(config['widths']) == len(config['depths']) + assert len(config['widths']) == len(config['d_ffs']) + levels = [] + for i, (depth, width, d_ff) in enumerate(zip(config['depths'], config['widths'], config['d_ffs'])): + if i < len(config['depths']) - 1: + self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(64, 7) + else: + self_attn = models.image_transformer_v2.GlobalAttentionSpec(64) + levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn)) + mapping = models.image_transformer_v2.MappingSpec(2, config['widths'][-1], config['d_ffs'][-1]) + model = models.ImageTransformerDenoiserModelV2( + levels=levels, + mapping=mapping, + in_channels=config['input_channels'], + out_channels=config['input_channels'], + patch_size=config['patch_size'], + num_classes=num_classes + 1 if num_classes else 0, + dropout=config['dropout_rate'], + ) else: raise ValueError(f'unsupported model type {config["type"]}') return model diff --git a/k_diffusion/models/__init__.py b/k_diffusion/models/__init__.py index 74986c9..14cb6d2 100644 --- a/k_diffusion/models/__init__.py +++ b/k_diffusion/models/__init__.py @@ -1,3 +1,4 @@ from .flags import checkpointing, get_checkpointing from .image_v1 import ImageDenoiserModelV1 from .image_transformer_v1 import ImageTransformerDenoiserModelV1 +from .image_transformer_v2 import ImageTransformerDenoiserModelV2 diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py new file mode 100644 index 0000000..c020fb7 --- /dev/null +++ b/k_diffusion/models/image_transformer_v2.py @@ -0,0 +1,528 @@ +"""k-diffusion transformer diffusion models, version 2.""" + +from dataclasses import dataclass +from functools import reduce, partial +import math +from typing import Union + +from einops import rearrange +import natten +import torch +from torch import nn +import torch._dynamo +from torch.nn import functional as F + +from . import flags +from .. import layers +from .axial_rope import make_axial_pos + + +try: + import flash_attn + from flash_attn.layers import rotary +except ImportError: + flash_attn = None + rotary = None + + +if flags.get_use_compile(): + torch._dynamo.config.cache_size_limit = max(64, torch._dynamo.config.cache_size_limit) + torch._dynamo.config.suppress_errors = True + + +# Helpers + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + + +def checkpoint(function, *args, **kwargs): + if flags.get_checkpointing(): + kwargs.setdefault("use_reentrant", True) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + else: + return function(*args, **kwargs) + + +def downscale_pos(pos): + pos = rearrange(pos, "... (h nh) (w nw) e -> ... h w (nh nw) e", nh=2, nw=2) + return torch.mean(pos, dim=-2) + + +# Param tags + +def tag_param(param, tag): + if not hasattr(param, "_tags"): + param._tags = set([tag]) + else: + param._tags.add(tag) + return param + + +def tag_module(module, tag): + for param in module.parameters(): + tag_param(param, tag) + return module + + +def apply_wd(module): + for name, param in module.named_parameters(): + if name.endswith("weight"): + tag_param(param, "wd") + return module + + +def filter_params(function, module): + for param in module.parameters(): + tags = getattr(param, "_tags", set()) + if function(tags): + yield param + + +# Kernels + +def compile(function, *args, **kwargs): + if not flags.get_use_compile(): + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + + +@compile +def scale_for_cosine_sim(q, k, scale, eps): + dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32)) + sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) + sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True) + sqrt_scale = torch.sqrt(scale.to(dtype)) + scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps) + scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps) + return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype) + + +@compile +def scale_for_cosine_sim_qkv(qkv, scale, eps): + q, k, v = qkv.unbind(2) + q, k = scale_for_cosine_sim(q, k, scale[:, None], eps) + return torch.stack((q, k, v), dim=2) + + +# Layers + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, eps=1e-6): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = apply_wd(zero_init(nn.Linear(cond_features, features, bias=False))) + tag_module(self.linear, "mapping") + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, None, :] + 1, self.eps) + + +# Rotary position embeddings + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_emb(x, cos, sin): + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]], dim=-1, + ) + + +class AxialRoPE(nn.Module): + def __init__(self, dim): + super().__init__() + freqs = torch.linspace(math.log(math.pi), math.log(10.0 * math.pi), dim // 4 + 1)[:-1].exp() + self.register_buffer("freqs", freqs) + + def extra_repr(self): + return f"dim={self.freqs.shape[-1] * 4}" + + def forward(self, pos): + freqs_h = pos[..., 0:1] * self.freqs.to(pos.dtype) + freqs_w = pos[..., 1:2] * self.freqs.to(pos.dtype) + freqs = torch.cat((freqs_h, freqs_w), dim=-1) + return freqs.cos(), freqs.sin() + + +# Transformer layers + + +def use_flash_2(x, check_dtype=True): + if not flags.get_use_flash_attention_2(): + return False + if flash_attn is None: + return False + if x.device.type != "cuda": + return False + if check_dtype and x.dtype not in (torch.float16, torch.bfloat16): + return False + return True + + +class SelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) + self.pos_emb = AxialRoPE(d_head // 2) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}," + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) + cos, sin = self.pos_emb(pos) + if use_flash_2(qkv): + qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) + qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) + qkv = rotary.apply_rotary_emb_qkv_(qkv, cos, sin) + x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) + x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) + else: + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + x = F.scaled_dot_product_attention(q, k, v, scale=1.0) + x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + +class NeighborhoodSelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.kernel_size = kernel_size + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.pos_emb = AxialRoPE(d_head // 2) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}, kernel_size={self.kernel_size}" + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + if use_flash_2(qkv, check_dtype=False): + qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) + pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) + cos, sin = self.pos_emb(pos) + qkv = rotary.apply_rotary_emb_qkv_(qkv, cos, sin) + q, k, v = rearrange(qkv, "n (h w) t nh e -> t n nh h w e", h=skip.shape[-3], w=skip.shape[-2]) + else: + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + cos, sin = self.pos_emb(pos) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q / math.sqrt(self.d_head) + qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) + a = torch.softmax(qk, dim=-1) + x = natten.functional.natten2dav(a, v, self.kernel_size, 1) + x = rearrange(x, "n nh h w e -> n h w (nh e)") + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + +class FeedForwardBlock(nn.Module): + def __init__(self, d_model, d_ff, cond_features, dropout=0.0): + super().__init__() + self.norm = AdaRMSNorm(d_model, cond_features) + self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) + self.dropout = nn.Dropout(dropout) + self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) + + def forward(self, x, cond): + skip = x + x = self.norm(x, cond) + x = self.up_proj(x) + x = self.dropout(x) + x = self.down_proj(x) + return x + skip + + +class TransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0): + super().__init__() + self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + +class NeighborhoodTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, kernel_size, dropout=0.0): + super().__init__() + self.self_attn = NeighborhoodSelfAttentionBlock(d_model, d_head, cond_features, kernel_size, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + +class Level(nn.ModuleList): + def forward(self, x, *args, **kwargs): + for layer in self: + x = layer(x, *args, **kwargs) + return x + + +# Mapping network + +class MappingFeedForwardBlock(nn.Module): + def __init__(self, d_model, d_ff, dropout=0.0): + super().__init__() + self.norm = RMSNorm(d_model) + self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) + self.dropout = nn.Dropout(dropout) + self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) + + def forward(self, x): + skip = x + x = self.norm(x) + x = self.up_proj(x) + x = self.dropout(x) + x = self.down_proj(x) + return x + skip + + +class MappingNetwork(nn.Module): + def __init__(self, n_layers, d_model, d_ff, dropout=0.0): + super().__init__() + self.in_norm = RMSNorm(d_model) + self.blocks = nn.ModuleList([MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]) + self.out_norm = RMSNorm(d_model) + + def forward(self, x): + x = self.in_norm(x) + for block in self.blocks: + x = block(x) + x = self.out_norm(x) + return x + + +# Token merging and splitting + +class TokenMerge(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = nn.Linear(in_features * self.h * self.w, out_features, bias=False) + + def forward(self, x): + x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w) + return self.proj(x) + + +class TokenSplit(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = nn.Linear(in_features, out_features * self.h * self.w, bias=False) + + def forward(self, x): + x = self.proj(x) + return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) + + +# Configuration + +@dataclass +class GlobalAttentionSpec: + d_head: int + + +@dataclass +class NeighborhoodAttentionSpec: + d_head: int + kernel_size: int + + +@dataclass +class LevelSpec: + depth: int + width: int + d_ff: int + self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec] + + +@dataclass +class MappingSpec: + depth: int + width: int + d_ff: int + + +# Model class + +class ImageTransformerDenoiserModelV2(nn.Module): + def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0, dropout=0.0): + super().__init__() + self.num_classes = num_classes + + self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size) + + self.time_emb = layers.FourierFeatures(1, mapping.width) + self.time_in_proj = nn.Linear(mapping.width, mapping.width, bias=False) + self.aug_emb = layers.FourierFeatures(9, mapping.width) + self.aug_in_proj = nn.Linear(mapping.width, mapping.width, bias=False) + self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None + self.mapping_cond_in_proj = nn.Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None + self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=dropout), "mapping") + + self.d_levels, self.u_levels, self.skip_scales = nn.ModuleList(), nn.ModuleList(), nn.ParameterList() + for i, spec in enumerate(levels): + if isinstance(spec.self_attn, GlobalAttentionSpec): + layer_factory = partial(TransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) + elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): + layer_factory = partial(NeighborhoodTransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=dropout) + else: + raise ValueError(f"unsupported self attention spec {spec.self_attn}") + + if i < len(levels) - 1: + self.d_levels.append(Level([layer_factory() for _ in range(spec.depth)])) + self.u_levels.append(Level([layer_factory() for _ in range(spec.depth)])) + else: + self.mid_level = Level([layer_factory() for _ in range(spec.depth)]) + + self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) + self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) + self.skip_scales = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(len(levels) - 1)]) + + self.out_norm = RMSNorm(levels[0].width) + self.patch_out = TokenSplit(levels[0].width, out_channels, patch_size) + nn.init.zeros_(self.patch_out.proj.weight) + + def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3): + wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self) + no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self) + mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self) + mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self) + groups = [ + {"params": list(wd), "lr": base_lr}, + {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0}, + {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale}, + {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0} + ] + return groups + + def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): + # Patching + x = x.movedim(-3, -1) + x = self.patch_in(x) + # TODO: pixel aspect ratio for nonsquare patches + pos = make_axial_pos(x.shape[-3], x.shape[-2], device=x.device).view(x.shape[-3], x.shape[-2], 2) + + # Mapping network + if class_cond is None and self.class_emb is not None: + raise ValueError("class_cond must be specified if num_classes > 0") + if mapping_cond is None and self.mapping_cond_in_proj is not None: + raise ValueError("mapping_cond must be specified if mapping_cond_dim > 0") + + c_noise = torch.log(sigma) / 4 + time_emb = self.time_in_proj(self.time_emb(c_noise[..., None])) + aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond + aug_emb = self.aug_in_proj(self.aug_emb(aug_cond)) + class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0 + mapping_emb = self.mapping_cond_in_proj(mapping_cond) if self.mapping_cond_in_proj is not None else 0 + cond = self.mapping(time_emb + aug_emb + class_emb + mapping_emb) + + # Hourglass transformer + skips, poses = [], [] + for d_level, merge in zip(self.d_levels, self.merges): + x = d_level(x, pos, cond) + skips.append(x) + poses.append(pos) + x = merge(x) + pos = downscale_pos(pos) + + x = self.mid_level(x, pos, cond) + + for u_level, split, skip_scale, skip, pos in reversed(list(zip(self.u_levels, self.splits, self.skip_scales, skips, poses))): + x = split(x) + x = torch.addcmul(x, skip, skip_scale) + x = u_level(x, pos, cond) + + # Unpatching + x = self.out_norm(x) + x = self.patch_out(x) + x = x.movedim(-1, -3) + + return x From 48974bcdeed455d1b7848f175b4205b642744ec0 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 21:39:36 +0000 Subject: [PATCH 02/43] If NATTEN is not installed, do not error until it is used --- k_diffusion/models/image_transformer_v2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index c020fb7..3875297 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -6,7 +6,6 @@ from typing import Union from einops import rearrange -import natten import torch from torch import nn import torch._dynamo @@ -17,6 +16,11 @@ from .axial_rope import make_axial_pos +try: + import natten +except ImportError: + natten = None + try: import flash_attn from flash_attn.layers import rotary @@ -285,6 +289,8 @@ def forward(self, x, pos, cond): q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) q = q / math.sqrt(self.d_head) + if natten is None: + raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) a = torch.softmax(qk, dim=-1) x = natten.functional.natten2dav(a, v, self.kernel_size, 1) From 6db5659bdd2404cb82bfc367aba6ce7d1c4ec6a9 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 21:40:07 +0000 Subject: [PATCH 03/43] PyTorch 2.0 compatibility for image_transformer_v2 --- k_diffusion/models/image_transformer_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 3875297..2b2b1a6 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -248,10 +248,10 @@ def forward(self, x, pos, cond): x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) - x = F.scaled_dot_product_attention(q, k, v, scale=1.0) + x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) x = self.out_proj(x) From 53b201317ac08d4906ad27d1c491608bb68ac161 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 22:03:48 +0000 Subject: [PATCH 04/43] Add image_transformer_v2 experimental branch info to README --- README.md | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/README.md b/README.md index d50c579..ea3bdba 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,36 @@ An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch, with enhancements and additional features, such as improved sampling algorithms and transformer-based diffusion models. +## Hourglass transformer experimental branch + +**This branch is under active development.** + +This branch of `k-diffusion` is for testing an experimental model type, `image_transformer_v2`, that uses ideas from [Hourglass Transformer](https://arxiv.org/abs/2110.13711) and [DiT](https://arxiv.org/abs/2212.09748). + +### Requirements + +To use the new model type you will need to install custom CUDA kernels: + +* [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. + +* [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention and rotary position embeddings. It will fall back to plain PyTorch for both of these if it is not installed. + +Also, you should make sure your PyTorch installation is capable of using `torch.compile()` (for instance, if you are using Python 3.11, you should use a PyTorch nightly build instead of 2.0). It will fall back to eager mode if `torch.compile()` is not available, but it will be slower and use more memory in training. + +### Usage + +In the `"model"` key of the config file: + +1. Set the `"type"` key to `"image_transformer_v2"`. + +1. The base patch size is set by the `"patch_size"` key, like `"patch_size": [4, 4]`. + +1. Model depth per level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level. + + All levels of the hierarchy except for the highest use sparse (neighborhood) attention with a 7x7 kernel. The highest level uses global attention. So the token count at every level but the highest can be very large. + +1. Model width per level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension, which is 64. + ## Installation `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e `. From 12ddfe1ac971b3f0e35f6780cbb0d1d0500741a9 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 22:37:12 +0000 Subject: [PATCH 05/43] Reorder imports --- k_diffusion/models/image_transformer_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 2b2b1a6..a56d0fb 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -1,7 +1,7 @@ """k-diffusion transformer diffusion models, version 2.""" from dataclasses import dataclass -from functools import reduce, partial +from functools import partial, reduce import math from typing import Union From c319913e4a59f448132a774756cc5be1c6dc8901 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 26 Sep 2023 22:55:58 +0000 Subject: [PATCH 06/43] Warn that image_transformer_v2 models may stop working due to backward incompatible changes --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ea3bdba..7909d0a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ An implementation of [Elucidating the Design Space of Diffusion-Based Generative ## Hourglass transformer experimental branch -**This branch is under active development.** +**This branch is under active development. Models of the new type that are trained with it may stop working due to backward incompatible changes.** This branch of `k-diffusion` is for testing an experimental model type, `image_transformer_v2`, that uses ideas from [Hourglass Transformer](https://arxiv.org/abs/2110.13711) and [DiT](https://arxiv.org/abs/2212.09748). From c88c97864ddbfebc641de085a6ac76f8de63da6b Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 27 Sep 2023 04:25:25 +0000 Subject: [PATCH 07/43] Use scaled cosine similarity neighborhood attention --- k_diffusion/models/image_transformer_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index a56d0fb..92c1372 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -266,6 +266,7 @@ def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): self.kernel_size = kernel_size self.norm = AdaRMSNorm(d_model, cond_features) self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) self.pos_emb = AxialRoPE(d_head // 2) self.dropout = nn.Dropout(dropout) self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) @@ -281,14 +282,15 @@ def forward(self, x, pos, cond): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) cos, sin = self.pos_emb(pos) + qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) qkv = rotary.apply_rotary_emb_qkv_(qkv, cos, sin) q, k, v = rearrange(qkv, "n (h w) t nh e -> t n nh h w e", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) cos, sin = self.pos_emb(pos) q = apply_rotary_emb(q, cos, sin) k = apply_rotary_emb(k, cos, sin) - q = q / math.sqrt(self.d_head) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) From fe32d7609a700509b106e3c51e4a42ba7eeb4da9 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 27 Sep 2023 04:25:48 +0000 Subject: [PATCH 08/43] Rename d_level and u_level to down_level and up_level --- k_diffusion/models/image_transformer_v2.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 92c1372..420b23d 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -455,7 +455,7 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c self.mapping_cond_in_proj = nn.Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=dropout), "mapping") - self.d_levels, self.u_levels, self.skip_scales = nn.ModuleList(), nn.ModuleList(), nn.ParameterList() + self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() for i, spec in enumerate(levels): if isinstance(spec.self_attn, GlobalAttentionSpec): layer_factory = partial(TransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) @@ -465,8 +465,8 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c raise ValueError(f"unsupported self attention spec {spec.self_attn}") if i < len(levels) - 1: - self.d_levels.append(Level([layer_factory() for _ in range(spec.depth)])) - self.u_levels.append(Level([layer_factory() for _ in range(spec.depth)])) + self.down_levels.append(Level([layer_factory() for _ in range(spec.depth)])) + self.up_levels.append(Level([layer_factory() for _ in range(spec.depth)])) else: self.mid_level = Level([layer_factory() for _ in range(spec.depth)]) @@ -514,8 +514,8 @@ def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): # Hourglass transformer skips, poses = [], [] - for d_level, merge in zip(self.d_levels, self.merges): - x = d_level(x, pos, cond) + for down_level, merge in zip(self.down_levels, self.merges): + x = down_level(x, pos, cond) skips.append(x) poses.append(pos) x = merge(x) @@ -523,10 +523,10 @@ def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): x = self.mid_level(x, pos, cond) - for u_level, split, skip_scale, skip, pos in reversed(list(zip(self.u_levels, self.splits, self.skip_scales, skips, poses))): + for up_level, split, skip_scale, skip, pos in reversed(list(zip(self.up_levels, self.splits, self.skip_scales, skips, poses))): x = split(x) x = torch.addcmul(x, skip, skip_scale) - x = u_level(x, pos, cond) + x = up_level(x, pos, cond) # Unpatching x = self.out_norm(x) From c78347412ae2f524bae2a540155690d5118793a7 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 27 Sep 2023 06:44:36 +0000 Subject: [PATCH 09/43] d_ff = width * 3 --- k_diffusion/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 42b92a0..cdd9ceb 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -121,7 +121,7 @@ def load_config(path_or_dict): if not config['model']['d_ffs']: d_ffs = [] for width in config['model']['widths']: - d_ffs.append(round_to_power_of_two(width * 8 / 3, tol=0.05)) + d_ffs.append(width * 3) config['model']['d_ffs'] = d_ffs return merge(defaults, config) From 8eb5d834b27406e105539485b700f3b534aca6ec Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 27 Sep 2023 08:52:18 +0000 Subject: [PATCH 10/43] Add 'self_attns' config key to image_transformer_v2 --- README.md | 16 +++++++++++++--- k_diffusion/config.py | 18 ++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 7909d0a..f124ae6 100644 --- a/README.md +++ b/README.md @@ -26,11 +26,21 @@ In the `"model"` key of the config file: 1. The base patch size is set by the `"patch_size"` key, like `"patch_size": [4, 4]`. -1. Model depth per level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level. +1. Model depth for each level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level. - All levels of the hierarchy except for the highest use sparse (neighborhood) attention with a 7x7 kernel. The highest level uses global attention. So the token count at every level but the highest can be very large. +1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension, which is 64. -1. Model width per level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension, which is 64. +1. The self-attention mechanism for each level of the hierarchy is specified by the `"self_attns"` config key, like: + + ```json + "self_attns": [ + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "global", "d_head": 64}, + ] + ``` + + If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel. The highest level uses global attention with 64 dim heads. So the token count at every level but the highest can be very large. ## Installation diff --git a/k_diffusion/config.py b/k_diffusion/config.py index cdd9ceb..8dbdf0d 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -58,6 +58,7 @@ def load_config(path_or_dict): defaults_image_transformer_v2 = { 'model': { 'd_ffs': None, + 'self_attns': None, 'augment_wrapper': False, 'skip_stages': 0, 'has_variance': False, @@ -123,6 +124,13 @@ def load_config(path_or_dict): for width in config['model']['widths']: d_ffs.append(width * 3) config['model']['d_ffs'] = d_ffs + if not config['model']['self_attns']: + self_attns = [] + default_neighborhood = {"type": "neighborhood", "d_head": 64, "kernel_size": 7} + default_global = {"type": "global", "d_head": 64} + for i in range(len(config['model']['widths'])): + self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global) + config['model']['self_attns'] = self_attns return merge(defaults, config) @@ -164,11 +172,13 @@ def make_model(config): assert len(config['widths']) == len(config['depths']) assert len(config['widths']) == len(config['d_ffs']) levels = [] - for i, (depth, width, d_ff) in enumerate(zip(config['depths'], config['widths'], config['d_ffs'])): - if i < len(config['depths']) - 1: - self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(64, 7) + for i, (depth, width, d_ff, self_attn) in enumerate(zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'])): + if self_attn['type'] == 'neighborhood': + self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) + elif self_attn['type'] == 'global': + self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) else: - self_attn = models.image_transformer_v2.GlobalAttentionSpec(64) + raise ValueError(f'unsupported self attention type {self_attn["type"]}') levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn)) mapping = models.image_transformer_v2.MappingSpec(2, config['widths'][-1], config['d_ffs'][-1]) model = models.ImageTransformerDenoiserModelV2( From 0c728f69d9395b7a549a71746033ff9b25038dbe Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 27 Sep 2023 13:53:01 +0000 Subject: [PATCH 11/43] Log accumulated wall clock time to metrics log --- k_diffusion/utils.py | 28 ++++++++++++++++++++++++++++ train.py | 9 +++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index 311d961..f14c643 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -4,6 +4,7 @@ from pathlib import Path import shutil import threading +import time import urllib import warnings @@ -455,3 +456,30 @@ def ema_update_dict(values, updates, decay): values[k] *= decay values[k] += (1 - decay) * v return values + + +class Timer: + def __init__(self, elapsed=0.0): + """A simple counter of elapsed time.""" + self.elapsed = elapsed + self.last_time = None + + def get(self, time_value=None): + """Updates and returns the elapsed time.""" + time_value = time_value or time.time() + if self.last_time: + self.elapsed += time_value - self.last_time + self.last_time = time_value + return self.elapsed + + def start(self, time_value=None): + """Starts counting elapsed time.""" + time_value = time_value or time.time() + self.get(time_value) + self.last_time = time_value + + def stop(self, time_value=None): + """Stops counting elapsed time.""" + time_value = time_value or time.time() + self.get(time_value) + self.last_time = None diff --git a/train.py b/train.py index 27182a7..956b1cf 100755 --- a/train.py +++ b/train.py @@ -122,6 +122,7 @@ def main(): seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) torch.manual_seed(seeds[accelerator.process_index]) demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) + timer = K.utils.Timer() inner_model = K.config.make_model(config) inner_model_ema = deepcopy(inner_model) @@ -262,6 +263,7 @@ def main(): if args.gns and ckpt.get('gns_stats', None) is not None: gns_stats.load_state_dict(ckpt['gns_stats']) demo_gen.set_state(ckpt['demo_gen']) + timer = K.utils.Timer(ckpt.get('elapsed', 0.0)) del ckpt else: @@ -297,7 +299,7 @@ def main(): print('Computing features for reals...') reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) if accelerator.is_main_process: - metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'loss', 'fid', 'kid']) + metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'time', 'loss', 'fid', 'kid']) del train_iter cfg_scale = 1. @@ -362,7 +364,7 @@ def sample_fn(n): kid = K.evaluation.kid(fakes_features, reals_features) print(f'FID: {fid.item():g}, KID: {kid.item():g}') if accelerator.is_main_process: - metrics_log.write(step, ema_stats['loss'], fid.item(), kid.item()) + metrics_log.write(step, timer.get(), ema_stats['loss'], fid.item(), kid.item()) if use_wandb: wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) @@ -385,6 +387,7 @@ def save(): 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 'ema_stats': ema_stats, 'demo_gen': demo_gen.get_state(), + 'elapsed': timer.get(), } accelerator.save(obj, filename) if accelerator.is_main_process: @@ -398,6 +401,7 @@ def save(): try: while True: for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process): + timer.start() with accelerator.accumulate(model): reals, _, aug_cond = batch[image_key] class_cond, extra_args = None, {} @@ -451,6 +455,7 @@ def save(): wandb.log(log_dict, step=step) step += 1 + timer.stop() if step % args.demo_every == 0: demo() From 3b1f9e38e6cda7b535f7f4186c2e1d6009168b90 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 00:53:36 +0000 Subject: [PATCH 12/43] in-place apply_rotary_emb_() with backprop --- k_diffusion/models/image_transformer_v2.py | 56 +++++++++++++++------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 420b23d..4726b86 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -188,6 +188,36 @@ def apply_rotary_emb(x, cos, sin): ) +@compile +def _apply_rotary_emb_inplace_forward(x, cos, sin): + return x.copy_(apply_rotary_emb(x, cos, sin)) + + +@compile +def _apply_rotary_emb_inplace_backward(grad_output, cos, sin): + return apply_rotary_emb(grad_output, cos, -sin) + + +class ApplyRotaryEmbeddingInplace(torch.autograd.Function): + @staticmethod + def forward(x, cos, sin): + return _apply_rotary_emb_inplace_forward(x, cos, sin) + + @staticmethod + def setup_context(ctx, inputs, output): + _, cos, sin = inputs + ctx.save_for_backward(cos, sin) + + @staticmethod + def backward(ctx, grad_output): + cos, sin = ctx.saved_tensors + return _apply_rotary_emb_inplace_backward(grad_output, cos, sin), None, None + + +def apply_rotary_emb_(x, cos, sin): + return ApplyRotaryEmbeddingInplace.apply(x, cos, sin) + + class AxialRoPE(nn.Module): def __init__(self, dim): super().__init__() @@ -207,15 +237,13 @@ def forward(self, pos): # Transformer layers -def use_flash_2(x, check_dtype=True): +def use_flash_2(x): if not flags.get_use_flash_attention_2(): return False if flash_attn is None: return False if x.device.type != "cuda": return False - if check_dtype and x.dtype not in (torch.float16, torch.bfloat16): - return False return True @@ -249,8 +277,8 @@ def forward(self, x, pos, cond): else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + q = apply_rotary_emb_(q, cos, sin) + k = apply_rotary_emb_(k, cos, sin) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -278,19 +306,11 @@ def forward(self, x, pos, cond): skip = x x = self.norm(x, cond) qkv = self.qkv_proj(x) - if use_flash_2(qkv, check_dtype=False): - qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) - pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) - cos, sin = self.pos_emb(pos) - qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - qkv = rotary.apply_rotary_emb_qkv_(qkv, cos, sin) - q, k, v = rearrange(qkv, "n (h w) t nh e -> t n nh h w e", h=skip.shape[-3], w=skip.shape[-2]) - else: - q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) - cos, sin = self.pos_emb(pos) - q = apply_rotary_emb(q, cos, sin) - k = apply_rotary_emb(k, cos, sin) + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + cos, sin = self.pos_emb(pos) + q = apply_rotary_emb_(q, cos, sin) + k = apply_rotary_emb_(k, cos, sin) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) From 0530114f23416252c0ba29433c8b9fa7548af70c Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 01:27:55 +0000 Subject: [PATCH 13/43] Use apply_rotary_emb_() on packed qkv --- k_diffusion/models/image_transformer_v2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 4726b86..644377c 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -271,7 +271,9 @@ def forward(self, x, pos, cond): if use_flash_2(qkv): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - qkv = rotary.apply_rotary_emb_qkv_(qkv, cos, sin) + cos = torch.stack((cos, cos, torch.ones_like(cos)), dim=-2).unsqueeze(-2) + sin = torch.stack((sin, sin, torch.zeros_like(sin)), dim=-2).unsqueeze(-2) + qkv = apply_rotary_emb_(qkv, cos, sin) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: From 29d5baa1c1a04ead4fc5f023d14c4efa54e9431f Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 01:42:27 +0000 Subject: [PATCH 14/43] Update kernels section of README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f124ae6..2d9b4ca 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ To use the new model type you will need to install custom CUDA kernels: * [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. -* [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention and rotary position embeddings. It will fall back to plain PyTorch for both of these if it is not installed. +* [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. It will fall back to plain PyTorch if it is not installed. Also, you should make sure your PyTorch installation is capable of using `torch.compile()` (for instance, if you are using Python 3.11, you should use a PyTorch nightly build instead of 2.0). It will fall back to eager mode if `torch.compile()` is not available, but it will be slower and use more memory in training. From 39f27e3d8ec7d225b1f38f07852668cd75e80491 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 17:08:09 +0000 Subject: [PATCH 15/43] Further optimize apply_rotary_emb_() --- k_diffusion/models/image_transformer_v2.py | 24 ++++++++-------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 644377c..cda6d18 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -178,30 +178,22 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_emb(x, cos, sin): +@compile +def _apply_rotary_emb_inplace(x, cos, sin, conjugate=False): ro_dim = cos.shape[-1] * 2 assert ro_dim <= x.shape[-1] + sin = -sin if conjugate else sin cos = torch.cat((cos, cos), dim=-1) sin = torch.cat((sin, sin), dim=-1) - return torch.cat( - [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim]) * sin, x[..., ro_dim:]], dim=-1, - ) - - -@compile -def _apply_rotary_emb_inplace_forward(x, cos, sin): - return x.copy_(apply_rotary_emb(x, cos, sin)) - - -@compile -def _apply_rotary_emb_inplace_backward(grad_output, cos, sin): - return apply_rotary_emb(grad_output, cos, -sin) + rotated = rotate_half(x[..., :ro_dim]) + x[..., :ro_dim].mul_(cos).addcmul_(rotated, sin) + return x class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod def forward(x, cos, sin): - return _apply_rotary_emb_inplace_forward(x, cos, sin) + return _apply_rotary_emb_inplace(x, cos, sin) @staticmethod def setup_context(ctx, inputs, output): @@ -211,7 +203,7 @@ def setup_context(ctx, inputs, output): @staticmethod def backward(ctx, grad_output): cos, sin = ctx.saved_tensors - return _apply_rotary_emb_inplace_backward(grad_output, cos, sin), None, None + return _apply_rotary_emb_inplace(grad_output, cos, sin, conjugate=True), None, None def apply_rotary_emb_(x, cos, sin): From 65eaf7efa3e4c28aec4d68b6d6d5c2ee6df4e631 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 17:59:31 +0000 Subject: [PATCH 16/43] Further optimize apply_rotary_emb_() --- k_diffusion/models/image_transformer_v2.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index cda6d18..32da12d 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -173,20 +173,13 @@ def forward(self, x, cond): # Rotary position embeddings -def rotate_half(x): - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -@compile def _apply_rotary_emb_inplace(x, cos, sin, conjugate=False): - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - sin = -sin if conjugate else sin - cos = torch.cat((cos, cos), dim=-1) - sin = torch.cat((sin, sin), dim=-1) - rotated = rotate_half(x[..., :ro_dim]) - x[..., :ro_dim].mul_(cos).addcmul_(rotated, sin) + d = cos.shape[-1] + assert d * 2 <= x.shape[-1] + x1, x2 = x[..., :d], x[..., d : d * 2] + tmp = x1.clone() + x1.mul_(cos).addcmul_(x2, sin, value=1 if conjugate else -1) + x2.mul_(cos).addcmul_(tmp, sin, value=-1 if conjugate else 1) return x From 4c1093c0fb4cbb916b0d042f7ade279e53a5eefc Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 28 Sep 2023 22:56:52 +0000 Subject: [PATCH 17/43] Add NoAttentionSpec and switch to lerp for skips --- k_diffusion/config.py | 2 + k_diffusion/models/image_transformer_v2.py | 43 ++++++++++++++++++---- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 8dbdf0d..7642b10 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -177,6 +177,8 @@ def make_model(config): self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) elif self_attn['type'] == 'global': self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) + elif self_attn['type'] == 'none': + self_attn = models.image_transformer_v2.NoAttentionSpec() else: raise ValueError(f'unsupported self attention type {self_attn["type"]}') levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn)) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 32da12d..4a87d32 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -350,6 +350,16 @@ def forward(self, x, pos, cond): return x +class NoAttentionTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, cond_features, dropout=0.0): + super().__init__() + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.ff, x, cond) + return x + + class Level(nn.ModuleList): def forward(self, x, *args, **kwargs): for layer in self: @@ -405,7 +415,7 @@ def forward(self, x): return self.proj(x) -class TokenSplit(nn.Module): +class TokenSplitWithoutSkip(nn.Module): def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] @@ -417,6 +427,20 @@ def forward(self, x): return rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) +class TokenSplit(nn.Module): + def __init__(self, in_features, out_features, patch_size=(2, 2)): + super().__init__() + self.h = patch_size[0] + self.w = patch_size[1] + self.proj = nn.Linear(in_features, out_features * self.h * self.w, bias=False) + self.fac = nn.Parameter(torch.ones(1) * 0.5) + + def forward(self, x, skip): + x = self.proj(x) + x = rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) + return torch.lerp(x, skip, self.fac.to(x.dtype)) + + # Configuration @dataclass @@ -430,12 +454,17 @@ class NeighborhoodAttentionSpec: kernel_size: int +@dataclass +class NoAttentionSpec: + pass + + @dataclass class LevelSpec: depth: int width: int d_ff: int - self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec] + self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, NoAttentionSpec] @dataclass @@ -468,6 +497,8 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c layer_factory = partial(TransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): layer_factory = partial(NeighborhoodTransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=dropout) + elif isinstance(spec.self_attn, NoAttentionSpec): + layer_factory = partial(NoAttentionTransformerLayer, spec.width, spec.d_ff, mapping.width, dropout=dropout) else: raise ValueError(f"unsupported self attention spec {spec.self_attn}") @@ -479,10 +510,9 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) - self.skip_scales = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(len(levels) - 1)]) self.out_norm = RMSNorm(levels[0].width) - self.patch_out = TokenSplit(levels[0].width, out_channels, patch_size) + self.patch_out = TokenSplitWithoutSkip(levels[0].width, out_channels, patch_size) nn.init.zeros_(self.patch_out.proj.weight) def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3): @@ -530,9 +560,8 @@ def forward(self, x, sigma, aug_cond=None, class_cond=None, mapping_cond=None): x = self.mid_level(x, pos, cond) - for up_level, split, skip_scale, skip, pos in reversed(list(zip(self.up_levels, self.splits, self.skip_scales, skips, poses))): - x = split(x) - x = torch.addcmul(x, skip, skip_scale) + for up_level, split, skip, pos in reversed(list(zip(self.up_levels, self.splits, skips, poses))): + x = split(x, skip) x = up_level(x, pos, cond) # Unpatching From df43aab3a122d370e2731233350d8c0bfd64bb0a Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Fri, 29 Sep 2023 23:54:54 +0000 Subject: [PATCH 18/43] Conjugate rotary --- k_diffusion/models/image_transformer_v2.py | 46 ++++++++++++++-------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 4a87d32..325da96 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -23,10 +23,8 @@ try: import flash_attn - from flash_attn.layers import rotary except ImportError: flash_attn = None - rotary = None if flags.get_use_compile(): @@ -172,35 +170,49 @@ def forward(self, x, cond): # Rotary position embeddings +@compile +def apply_rotary_emb(x, cos, sin, conj=False): + out_dtype = x.dtype + dtype = reduce(torch.promote_types, (x.dtype, cos.dtype, sin.dtype, torch.float32)) + d = cos.shape[-1] + assert d * 2 <= x.shape[-1] + x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] + x1, x2, cos, sin = x1.to(dtype), x2.to(dtype), cos.to(dtype), sin.to(dtype) + sin = -sin if conj else sin + y1 = (x1 * cos - x2 * sin).to(out_dtype) + y2 = (x2 * cos + x1 * sin).to(out_dtype) + return torch.cat((y1, y2, x3), dim=-1).to(out_dtype) -def _apply_rotary_emb_inplace(x, cos, sin, conjugate=False): + +def _apply_rotary_emb_inplace(x, cos, sin, conj): d = cos.shape[-1] assert d * 2 <= x.shape[-1] x1, x2 = x[..., :d], x[..., d : d * 2] tmp = x1.clone() - x1.mul_(cos).addcmul_(x2, sin, value=1 if conjugate else -1) - x2.mul_(cos).addcmul_(tmp, sin, value=-1 if conjugate else 1) + x1.mul_(cos).addcmul_(x2, sin, value=1 if conj else -1) + x2.mul_(cos).addcmul_(tmp, sin, value=-1 if conj else 1) return x class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod - def forward(x, cos, sin): - return _apply_rotary_emb_inplace(x, cos, sin) + def forward(x, cos, sin, conj): + return _apply_rotary_emb_inplace(x, cos, sin, conj=conj) @staticmethod def setup_context(ctx, inputs, output): - _, cos, sin = inputs + _, cos, sin, conj = inputs ctx.save_for_backward(cos, sin) + ctx.conj = conj @staticmethod def backward(ctx, grad_output): cos, sin = ctx.saved_tensors - return _apply_rotary_emb_inplace(grad_output, cos, sin, conjugate=True), None, None + return _apply_rotary_emb_inplace(grad_output, cos, sin, conj=not ctx.conj), None, None, None -def apply_rotary_emb_(x, cos, sin): - return ApplyRotaryEmbeddingInplace.apply(x, cos, sin) +def apply_rotary_emb_(x, cos, sin, conj=False): + return ApplyRotaryEmbeddingInplace.apply(x, cos, sin, conj) class AxialRoPE(nn.Module): @@ -229,6 +241,8 @@ def use_flash_2(x): return False if x.device.type != "cuda": return False + if x.dtype not in (torch.float16, torch.bfloat16): + return False return True @@ -258,14 +272,14 @@ def forward(self, x, pos, cond): qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) cos = torch.stack((cos, cos, torch.ones_like(cos)), dim=-2).unsqueeze(-2) sin = torch.stack((sin, sin, torch.zeros_like(sin)), dim=-2).unsqueeze(-2) - qkv = apply_rotary_emb_(qkv, cos, sin) + qkv = apply_rotary_emb_(qkv, cos, sin, conj=True) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - q = apply_rotary_emb_(q, cos, sin) - k = apply_rotary_emb_(k, cos, sin) + q = apply_rotary_emb_(q, cos, sin, conj=True) + k = apply_rotary_emb_(k, cos, sin, conj=True) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -296,8 +310,8 @@ def forward(self, x, pos, cond): q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) cos, sin = self.pos_emb(pos) - q = apply_rotary_emb_(q, cos, sin) - k = apply_rotary_emb_(k, cos, sin) + q = apply_rotary_emb_(q, cos, sin, conj=True) + k = apply_rotary_emb_(k, cos, sin, conj=True) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) From ea55f2666ceafb20585b67b1fc78d5b6b37b0929 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Fri, 29 Sep 2023 23:55:20 +0000 Subject: [PATCH 19/43] Reverse lerp direction --- k_diffusion/models/image_transformer_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 325da96..06ba2a9 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -452,7 +452,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): def forward(self, x, skip): x = self.proj(x) x = rearrange(x, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=self.h, nw=self.w) - return torch.lerp(x, skip, self.fac.to(x.dtype)) + return torch.lerp(skip, x, self.fac.to(x.dtype)) # Configuration From 501f7cfa0376e5d402ad1c3a714b5506b246dc15 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 30 Sep 2023 19:04:54 +0000 Subject: [PATCH 20/43] Add shifted window attention --- README.md | 18 ++- k_diffusion/config.py | 11 +- k_diffusion/models/image_transformer_v2.py | 157 +++++++++++++++++++-- 3 files changed, 172 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 2d9b4ca..b06ca05 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ In the `"model"` key of the config file: 1. Model depth for each level of the hierarchy is specified by the `"depths"` config key, like `"depths": [2, 2, 4]`. This constructs a model with two transformer layers at the first level (4x4 patches), followed by two at the second level (8x8 patches), followed by four at the highest level (16x16 patches), followed by two more at the second level, followed by two more at the first level. -1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension, which is 64. +1. Model width for each level of the hierarchy is specified by the `"widths"` config key, like `"widths": [192, 384, 768]`. The widths must be multiples of the attention head dimension. 1. The self-attention mechanism for each level of the hierarchy is specified by the `"self_attns"` config key, like: @@ -42,6 +42,20 @@ In the `"model"` key of the config file: If not specified, all levels of the hierarchy except for the highest use neighborhood attention with 64 dim heads and a 7x7 kernel. The highest level uses global attention with 64 dim heads. So the token count at every level but the highest can be very large. +1. As a fallback if you or your users cannot use NATTEN, you can also train a model with [shifted window attention](https://arxiv.org/abs/2103.14030) at the low levels of the hierarchy. Shifted window attention does not perform as well as neighborhood attention and it is slower to train and inference, but it does not require custom CUDA kernels. Specify it like: + + ```json + "self_attns": [ + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "global", "d_head": 64}, + ] + ``` + + The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type. + +1. FP16 training with this model type is unstable. It is recommended to use BF16 (`--mixed-precision bf16`) or FP32. + ## Installation `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e `. @@ -76,7 +90,7 @@ $ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME ## Enhancements/additional features -- k-diffusion has support for training transformer-based diffusion models (like [DiT](https://arxiv.org/abs/2212.09748) but improved). +- k-diffusion supports a highly efficient hierarchical transformer model type. - k-diffusion supports a soft version of [Min-SNR loss weighting](https://arxiv.org/abs/2303.09556) for improved training at high resolutions with less hyperparameters than the loss weighting used in Karras et al. (2022). diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 7642b10..96d4622 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -171,12 +171,15 @@ def make_model(config): elif config['type'] == 'image_transformer_v2': assert len(config['widths']) == len(config['depths']) assert len(config['widths']) == len(config['d_ffs']) + assert len(config['widths']) == len(config['self_attns']) levels = [] - for i, (depth, width, d_ff, self_attn) in enumerate(zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'])): - if self_attn['type'] == 'neighborhood': - self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) - elif self_attn['type'] == 'global': + for depth, width, d_ff, self_attn in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns']): + if self_attn['type'] == 'global': self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) + elif self_attn['type'] == 'neighborhood': + self_attn = models.image_transformer_v2.NeighborhoodAttentionSpec(self_attn.get('d_head', 64), self_attn.get('kernel_size', 7)) + elif self_attn['type'] == 'shifted-window': + self_attn = models.image_transformer_v2.ShiftedWindowAttentionSpec(self_attn.get('d_head', 64), self_attn['window_size']) elif self_attn['type'] == 'none': self_attn = models.image_transformer_v2.NoAttentionSpec() else: diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 06ba2a9..e191613 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -1,7 +1,7 @@ """k-diffusion transformer diffusion models, version 2.""" from dataclasses import dataclass -from functools import partial, reduce +from functools import reduce import math from typing import Union @@ -231,6 +231,93 @@ def forward(self, pos): return freqs.cos(), freqs.sin() +# Shifted window attention + +def window(window_size, x): + *b, h, w, c = x.shape + x = torch.reshape( + x, + (*b, h // window_size, window_size, w // window_size, window_size, c), + ) + x = torch.permute( + x, + (*range(len(b)), -5, -3, -4, -2, -1), + ) + return x + + +def unwindow(x): + *b, h, w, wh, ww, c = x.shape + x = torch.permute(x, (*range(len(b)), -5, -3, -4, -2, -1)) + x = torch.reshape(x, (*b, h * wh, w * ww, c)) + return x + + +def shifted_window(window_size, window_shift, x): + x = torch.roll(x, shifts=(window_shift, window_shift), dims=(-2, -3)) + windows = window(window_size, x) + return windows + + +def shifted_unwindow(window_shift, x): + x = unwindow(x) + x = torch.roll(x, shifts=(-window_shift, -window_shift), dims=(-2, -3)) + return x + + +def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None): + ph_coords = torch.arange(n_h_w, device=device) + pw_coords = torch.arange(n_w_w, device=device) + h_coords = torch.arange(w_h, device=device) + w_coords = torch.arange(w_w, device=device) + patch_h, patch_w, q_h, q_w, k_h, k_w = torch.meshgrid( + ph_coords, + pw_coords, + h_coords, + w_coords, + h_coords, + w_coords, + indexing="ij", + ) + is_top_patch = patch_h == 0 + is_left_patch = patch_w == 0 + q_above_shift = q_h < shift + k_above_shift = k_h < shift + q_left_of_shift = q_w < shift + k_left_of_shift = k_w < shift + m_corner = ( + is_left_patch + & is_top_patch + & (q_left_of_shift == k_left_of_shift) + & (q_above_shift == k_above_shift) + ) + m_left = is_left_patch & ~is_top_patch & (q_left_of_shift == k_left_of_shift) + m_top = ~is_left_patch & is_top_patch & (q_above_shift == k_above_shift) + m_rest = ~is_left_patch & ~is_top_patch + m = m_corner | m_left | m_top | m_rest + return m + + +def apply_window_attention(window_size, window_shift, q, k, v): + # prep windows and masks + q_windows = shifted_window(window_size, window_shift, q) + k_windows = shifted_window(window_size, window_shift, k) + v_windows = shifted_window(window_size, window_shift, v) + b, heads, h, w, wh, ww, d_head = q_windows.shape + mask = make_shifted_window_masks(h, w, wh, ww, window_shift, device=q.device) + q_seqs = torch.reshape(q_windows, (b, heads, h, w, wh * ww, d_head)) + k_seqs = torch.reshape(k_windows, (b, heads, h, w, wh * ww, d_head)) + v_seqs = torch.reshape(v_windows, (b, heads, h, w, wh * ww, d_head)) + mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) + + # do the attention here + qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=1.0) + + # unwindow + qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head)) + return shifted_unwindow(window_shift, qkv) + + # Transformer layers @@ -323,6 +410,39 @@ def forward(self, x, pos, cond): return x + skip +class ShiftedWindowSelfAttentionBlock(nn.Module): + def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dropout=0.0): + super().__init__() + self.d_head = d_head + self.n_heads = d_model // d_head + self.window_size = window_size + self.window_shift = window_shift + self.norm = AdaRMSNorm(d_model, cond_features) + self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) + self.pos_emb = AxialRoPE(d_head // 2) + self.dropout = nn.Dropout(dropout) + self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + + def extra_repr(self): + return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}" + + def forward(self, x, pos, cond): + skip = x + x = self.norm(x, cond) + qkv = self.qkv_proj(x) + q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + cos, sin = self.pos_emb(pos) + q = apply_rotary_emb_(q, cos, sin, conj=True) + k = apply_rotary_emb_(k, cos, sin, conj=True) + x = apply_window_attention(self.window_size, self.window_shift, q, k, v) + x = rearrange(x, "n nh h w e -> n h w (nh e)") + x = self.dropout(x) + x = self.out_proj(x) + return x + skip + + class FeedForwardBlock(nn.Module): def __init__(self, d_model, d_ff, cond_features, dropout=0.0): super().__init__() @@ -364,6 +484,19 @@ def forward(self, x, pos, cond): return x +class ShiftedWindowTransformerLayer(nn.Module): + def __init__(self, d_model, d_ff, d_head, cond_features, window_size, index, dropout=0.0): + super().__init__() + window_shift = window_size // 2 if index % 2 == 1 else 0 + self.self_attn = ShiftedWindowSelfAttentionBlock(d_model, d_head, cond_features, window_size, window_shift, dropout=dropout) + self.ff = FeedForwardBlock(d_model, d_ff, cond_features, dropout=dropout) + + def forward(self, x, pos, cond): + x = checkpoint(self.self_attn, x, pos, cond) + x = checkpoint(self.ff, x, cond) + return x + + class NoAttentionTransformerLayer(nn.Module): def __init__(self, d_model, d_ff, cond_features, dropout=0.0): super().__init__() @@ -468,6 +601,12 @@ class NeighborhoodAttentionSpec: kernel_size: int +@dataclass +class ShiftedWindowAttentionSpec: + d_head: int + window_size: int + + @dataclass class NoAttentionSpec: pass @@ -478,7 +617,7 @@ class LevelSpec: depth: int width: int d_ff: int - self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, NoAttentionSpec] + self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec] @dataclass @@ -508,19 +647,21 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() for i, spec in enumerate(levels): if isinstance(spec.self_attn, GlobalAttentionSpec): - layer_factory = partial(TransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) + layer_factory = lambda _: TransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): - layer_factory = partial(NeighborhoodTransformerLayer, spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=dropout) + layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=dropout) + elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec): + layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=dropout) elif isinstance(spec.self_attn, NoAttentionSpec): - layer_factory = partial(NoAttentionTransformerLayer, spec.width, spec.d_ff, mapping.width, dropout=dropout) + layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=dropout) else: raise ValueError(f"unsupported self attention spec {spec.self_attn}") if i < len(levels) - 1: - self.down_levels.append(Level([layer_factory() for _ in range(spec.depth)])) - self.up_levels.append(Level([layer_factory() for _ in range(spec.depth)])) + self.down_levels.append(Level([layer_factory(i) for i in range(spec.depth)])) + self.up_levels.append(Level([layer_factory(i + spec.depth) for i in range(spec.depth)])) else: - self.mid_level = Level([layer_factory() for _ in range(spec.depth)]) + self.mid_level = Level([layer_factory(i) for i in range(spec.depth)]) self.merges = nn.ModuleList([TokenMerge(spec_1.width, spec_2.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) self.splits = nn.ModuleList([TokenSplit(spec_2.width, spec_1.width) for spec_1, spec_2 in zip(levels[:-1], levels[1:])]) From 6dac00688c6f1b06232743ec11942ed80c551f2d Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 30 Sep 2023 19:52:34 +0000 Subject: [PATCH 21/43] Print world size and batch size in train.py --- train.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train.py b/train.py index 956b1cf..6bcab1c 100755 --- a/train.py +++ b/train.py @@ -117,6 +117,10 @@ def main(): device = accelerator.device unwrap = accelerator.unwrap_model print(f'Process {accelerator.process_index} using device: {device}', flush=True) + accelerator.wait_for_everyone() + if accelerator.is_main_process: + print(f'World size: {accelerator.num_processes}', flush=True) + print(f'Batch size: {args.batch_size * accelerator.num_processes}', flush=True) if args.seed is not None: seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) From 0a109b35a41904b0152f88b919ea53f5913523ca Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 30 Sep 2023 22:17:47 +0000 Subject: [PATCH 22/43] Update configs and README --- README.md | 18 ++++++-- configs/config_cifar10_transformer.json | 12 +++-- configs/config_mnist_transformer.json | 6 +-- configs/config_oxford_flowers.json | 46 +++++++++++++++++++ .../config_oxford_flowers_shifted_window.json | 46 +++++++++++++++++++ 5 files changed, 118 insertions(+), 10 deletions(-) create mode 100644 configs/config_oxford_flowers.json create mode 100644 configs/config_oxford_flowers_shifted_window.json diff --git a/README.md b/README.md index b06ca05..5a88b71 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ This branch of `k-diffusion` is for testing an experimental model type, `image_t To use the new model type you will need to install custom CUDA kernels: -* [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. +* [NATTEN](https://github.com/SHI-Labs/NATTEN/tree/main) for the sparse (neighborhood) attention used at low levels of the hierarchy. There is a [shifted window attention](https://arxiv.org/abs/2103.14030) version of the model type which does not require a custom CUDA kernel, but it does not perform as well and is slower to train and inference. * [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) for global attention. It will fall back to plain PyTorch if it is not installed. @@ -20,6 +20,20 @@ Also, you should make sure your PyTorch installation is capable of using `torch. ### Usage +#### Demo + +To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, run: + +```sh +python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16 +``` + +If you run out of memory, try adding `--checkpointing` or reducing the batch size. If you are using an older GPU (pre-Ampere), omit `--mixed-precision bf16` to train in FP32. It is not recommended to train in FP16. + +If you have NATTEN installed and working (preferred), you can train with neighborhood attention instead of shifted window attention by specifying `--config configs/config_oxford_flowers.json`. + +#### Config file + In the `"model"` key of the config file: 1. Set the `"type"` key to `"image_transformer_v2"`. @@ -54,8 +68,6 @@ In the `"model"` key of the config file: The window size at each level must evenly divide the image size at that level. Models trained with one attention type must be fine-tuned to be used with a different type. -1. FP16 training with this model type is unstable. It is recommended to use BF16 (`--mixed-precision bf16`) or FP32. - ## Installation `k-diffusion` can be installed via PyPI (`pip install k-diffusion`) but it will not include training and inference scripts, only library code that others can depend on. To run the training and inference scripts, clone this repository and run `pip install -e `. diff --git a/configs/config_cifar10_transformer.json b/configs/config_cifar10_transformer.json index 1a9ef09..958cd4c 100644 --- a/configs/config_cifar10_transformer.json +++ b/configs/config_cifar10_transformer.json @@ -1,11 +1,15 @@ { "model": { - "type": "image_transformer_v1", + "type": "image_transformer_v2", "input_channels": 3, "input_size": [32, 32], - "patch_size": [4, 4], - "width": 512, - "depth": 8, + "patch_size": [2, 2], + "depths": [2, 4], + "widths": [256, 512], + "self_attns": [ + {"type": "global"}, + {"type": "global"} + ], "loss_config": "karras", "loss_weighting": "soft-min-snr", "dropout_rate": 0.05, diff --git a/configs/config_mnist_transformer.json b/configs/config_mnist_transformer.json index 0ea1189..564441f 100644 --- a/configs/config_mnist_transformer.json +++ b/configs/config_mnist_transformer.json @@ -1,11 +1,11 @@ { "model": { - "type": "image_transformer_v1", + "type": "image_transformer_v2", "input_channels": 1, "input_size": [28, 28], "patch_size": [4, 4], - "width": 256, - "depth": 8, + "depths": [8], + "widths": [256], "loss_config": "karras", "loss_weighting": "soft-min-snr", "dropout_rate": 0.05, diff --git a/configs/config_oxford_flowers.json b/configs/config_oxford_flowers.json new file mode 100644 index 0000000..bae25d7 --- /dev/null +++ b/configs/config_oxford_flowers.json @@ -0,0 +1,46 @@ +{ + "model": { + "type": "image_transformer_v2", + "input_channels": 3, + "input_size": [256, 256], + "patch_size": [4, 4], + "depths": [2, 2, 4], + "widths": [128, 256, 512], + "self_attns": [ + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "neighborhood", "d_head": 64, "kernel_size": 7}, + {"type": "global", "d_head": 64} + ], + "loss_config": "karras", + "loss_weighting": "soft-min-snr", + "dropout_rate": 0.1, + "augment_prob": 0.0, + "sigma_data": 0.5, + "sigma_min": 1e-2, + "sigma_max": 160, + "sigma_sample_density": { + "type": "cosine-interpolated" + } + }, + "dataset": { + "type": "huggingface", + "location": "nelorth/oxford-flowers", + "image_key": "image" + }, + "optimizer": { + "type": "adamw", + "lr": 5e-4, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 1e-3 + }, + "lr_sched": { + "type": "constant", + "warmup": 0.0 + }, + "ema_sched": { + "type": "inverse", + "power": 0.75, + "max_value": 0.9999 + } +} diff --git a/configs/config_oxford_flowers_shifted_window.json b/configs/config_oxford_flowers_shifted_window.json new file mode 100644 index 0000000..c292f9b --- /dev/null +++ b/configs/config_oxford_flowers_shifted_window.json @@ -0,0 +1,46 @@ +{ + "model": { + "type": "image_transformer_v2", + "input_channels": 3, + "input_size": [256, 256], + "patch_size": [4, 4], + "depths": [2, 2, 4], + "widths": [128, 256, 512], + "self_attns": [ + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "shifted-window", "d_head": 64, "window_size": 8}, + {"type": "global", "d_head": 64} + ], + "loss_config": "karras", + "loss_weighting": "soft-min-snr", + "dropout_rate": 0.1, + "augment_prob": 0.0, + "sigma_data": 0.5, + "sigma_min": 1e-2, + "sigma_max": 160, + "sigma_sample_density": { + "type": "cosine-interpolated" + } + }, + "dataset": { + "type": "huggingface", + "location": "nelorth/oxford-flowers", + "image_key": "image" + }, + "optimizer": { + "type": "adamw", + "lr": 5e-4, + "betas": [0.9, 0.95], + "eps": 1e-8, + "weight_decay": 1e-3 + }, + "lr_sched": { + "type": "constant", + "warmup": 0.0 + }, + "ema_sched": { + "type": "inverse", + "power": 0.75, + "max_value": 0.9999 + } +} From 71bc989cda99aa3d6d6e10e7ecb159753fad79b9 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 30 Sep 2023 22:27:23 +0000 Subject: [PATCH 23/43] Don't require pytorch nightly for shifted window attention --- k_diffusion/models/image_transformer_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index e191613..e7d9f69 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -311,7 +311,7 @@ def apply_window_attention(window_size, window_shift, q, k, v): mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) # do the attention here - qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=1.0) + qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask) # unwindow qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head)) @@ -432,7 +432,7 @@ def forward(self, x, pos, cond): x = self.norm(x, cond) qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) cos, sin = self.pos_emb(pos) q = apply_rotary_emb_(q, cos, sin, conj=True) k = apply_rotary_emb_(k, cos, sin, conj=True) From b4f9c6994231d1293243c1e43efa1448d5d4020e Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 30 Sep 2023 23:01:06 +0000 Subject: [PATCH 24/43] Note that you have to install Hugging Face datasets for Oxford Flowers --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a88b71..422bc78 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,13 @@ Also, you should make sure your PyTorch installation is capable of using `torch. #### Demo -To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, run: +To train a 256x256 RGB model on [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers) without installing custom CUDA kernels, install [Hugging Face Datasets](https://huggingface.co/docs/datasets/index): + +```sh +pip install datasets +``` + +and run: ```sh python train.py --config configs/config_oxford_flowers_shifted_window.json --name flowers_demo_001 --evaluate-n 0 --batch-size 32 --sample-n 36 --mixed-precision bf16 From 31d4bd271adf135f56d3c3f21740521a89e1e483 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sun, 1 Oct 2023 21:13:16 +0000 Subject: [PATCH 25/43] Set default mapping network width to 256, allow configuring mapping network --- k_diffusion/config.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/k_diffusion/config.py b/k_diffusion/config.py index 96d4622..bd8e3b7 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -57,6 +57,10 @@ def load_config(path_or_dict): } defaults_image_transformer_v2 = { 'model': { + 'mapping_width': 256, + 'mapping_depth': 2, + 'mapping_d_ff': None, + 'mapping_cond_dim': 0, 'd_ffs': None, 'self_attns': None, 'augment_wrapper': False, @@ -119,6 +123,8 @@ def load_config(path_or_dict): config['model']['d_ff'] = round_to_power_of_two(config['model']['width'] * 8 / 3, tol=0.05) elif config['model']['type'] == 'image_transformer_v2': config = merge(defaults_image_transformer_v2, config) + if not config['model']['mapping_d_ff']: + config['model']['mapping_d_ff'] = config['model']['mapping_width'] * 3 if not config['model']['d_ffs']: d_ffs = [] for width in config['model']['widths']: @@ -185,7 +191,7 @@ def make_model(config): else: raise ValueError(f'unsupported self attention type {self_attn["type"]}') levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn)) - mapping = models.image_transformer_v2.MappingSpec(2, config['widths'][-1], config['d_ffs'][-1]) + mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff']) model = models.ImageTransformerDenoiserModelV2( levels=levels, mapping=mapping, @@ -193,6 +199,7 @@ def make_model(config): out_channels=config['input_channels'], patch_size=config['patch_size'], num_classes=num_classes + 1 if num_classes else 0, + mapping_cond_dim=config['mapping_cond_dim'], dropout=config['dropout_rate'], ) else: From 1b2bc7cecc62da1abb19cd0f06eeb4a75912857f Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Mon, 2 Oct 2023 19:57:36 +0000 Subject: [PATCH 26/43] Different rotary freqs per head --- k_diffusion/models/image_transformer_v2.py | 41 ++++++++++++---------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index e7d9f69..5f5548c 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -184,6 +184,7 @@ def apply_rotary_emb(x, cos, sin, conj=False): return torch.cat((y1, y2, x3), dim=-1).to(out_dtype) +@compile def _apply_rotary_emb_inplace(x, cos, sin, conj): d = cos.shape[-1] assert d * 2 <= x.shape[-1] @@ -216,17 +217,19 @@ def apply_rotary_emb_(x, cos, sin, conj=False): class AxialRoPE(nn.Module): - def __init__(self, dim): + def __init__(self, dim, n_heads): super().__init__() - freqs = torch.linspace(math.log(math.pi), math.log(10.0 * math.pi), dim // 4 + 1)[:-1].exp() - self.register_buffer("freqs", freqs) + log_min = math.log(math.pi) + log_max = math.log(10.0 * math.pi) + freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp() + self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous()) def extra_repr(self): - return f"dim={self.freqs.shape[-1] * 4}" + return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" def forward(self, pos): - freqs_h = pos[..., 0:1] * self.freqs.to(pos.dtype) - freqs_w = pos[..., 1:2] * self.freqs.to(pos.dtype) + freqs_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) + freqs_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) freqs = torch.cat((freqs_h, freqs_w), dim=-1) return freqs.cos(), freqs.sin() @@ -341,7 +344,7 @@ def __init__(self, d_model, d_head, cond_features, dropout=0.0): self.norm = AdaRMSNorm(d_model, cond_features) self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) - self.pos_emb = AxialRoPE(d_head // 2) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) @@ -357,16 +360,16 @@ def forward(self, x, pos, cond): if use_flash_2(qkv): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - cos = torch.stack((cos, cos, torch.ones_like(cos)), dim=-2).unsqueeze(-2) - sin = torch.stack((sin, sin, torch.zeros_like(sin)), dim=-2).unsqueeze(-2) - qkv = apply_rotary_emb_(qkv, cos, sin, conj=True) + cos = torch.stack((cos, cos, torch.ones_like(cos)), dim=-3) + sin = torch.stack((sin, sin, torch.zeros_like(sin)), dim=-3) + qkv = apply_rotary_emb_(qkv, cos, sin) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - q = apply_rotary_emb_(q, cos, sin, conj=True) - k = apply_rotary_emb_(k, cos, sin, conj=True) + q = apply_rotary_emb_(q, cos, sin) + k = apply_rotary_emb_(k, cos, sin) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -383,7 +386,7 @@ def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): self.norm = AdaRMSNorm(d_model, cond_features) self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) - self.pos_emb = AxialRoPE(d_head // 2) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) @@ -397,8 +400,9 @@ def forward(self, x, pos, cond): q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) cos, sin = self.pos_emb(pos) - q = apply_rotary_emb_(q, cos, sin, conj=True) - k = apply_rotary_emb_(k, cos, sin, conj=True) + cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) + q = apply_rotary_emb_(q, cos, sin) + k = apply_rotary_emb_(k, cos, sin) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) @@ -420,7 +424,7 @@ def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dr self.norm = AdaRMSNorm(d_model, cond_features) self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) - self.pos_emb = AxialRoPE(d_head // 2) + self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) @@ -434,8 +438,9 @@ def forward(self, x, pos, cond): q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) cos, sin = self.pos_emb(pos) - q = apply_rotary_emb_(q, cos, sin, conj=True) - k = apply_rotary_emb_(k, cos, sin, conj=True) + cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) + q = apply_rotary_emb_(q, cos, sin) + k = apply_rotary_emb_(k, cos, sin) x = apply_window_attention(self.window_size, self.window_shift, q, k, v) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) From 0e6bb3969c3d73e10d118c9e2deb0144b2906083 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 3 Oct 2023 13:56:23 +0000 Subject: [PATCH 27/43] Fix rotary for non-flash attention --- k_diffusion/models/image_transformer_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 5f5548c..87de77c 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -368,6 +368,7 @@ def forward(self, x, pos, cond): else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) + cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) q = apply_rotary_emb_(q, cos, sin) k = apply_rotary_emb_(k, cos, sin) x = F.scaled_dot_product_attention(q, k, v) From fde0b4e263e10488efcd90d0b4f5eb54de9d24d7 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 3 Oct 2023 20:30:07 +0000 Subject: [PATCH 28/43] Change cos/sin axes to index from right consistently --- k_diffusion/models/image_transformer_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 87de77c..8fad967 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -368,7 +368,7 @@ def forward(self, x, pos, cond): else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) + cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) q = apply_rotary_emb_(q, cos, sin) k = apply_rotary_emb_(k, cos, sin) x = F.scaled_dot_product_attention(q, k, v) @@ -401,7 +401,7 @@ def forward(self, x, pos, cond): q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) cos, sin = self.pos_emb(pos) - cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) + cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) q = apply_rotary_emb_(q, cos, sin) k = apply_rotary_emb_(k, cos, sin) if natten is None: @@ -439,7 +439,7 @@ def forward(self, x, pos, cond): q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) cos, sin = self.pos_emb(pos) - cos, sin = cos.movedim(-2, 0), sin.movedim(-2, 0) + cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) q = apply_rotary_emb_(q, cos, sin) k = apply_rotary_emb_(k, cos, sin) x = apply_window_attention(self.window_size, self.window_shift, q, k, v) From e6760d3423fa707f2d452c1d84b032f17861c3e5 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 3 Oct 2023 20:51:53 +0000 Subject: [PATCH 29/43] Use lru_cache on make_shifted_window_masks() --- k_diffusion/models/image_transformer_v2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 8fad967..d146af7 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -1,7 +1,7 @@ """k-diffusion transformer diffusion models, version 2.""" from dataclasses import dataclass -from functools import reduce +from functools import lru_cache, reduce import math from typing import Union @@ -268,6 +268,7 @@ def shifted_unwindow(window_shift, x): return x +@lru_cache def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None): ph_coords = torch.arange(n_h_w, device=device) pw_coords = torch.arange(n_w_w, device=device) From 46bd364bb4a41e5d3ae7a367ba7018271df98812 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 4 Oct 2023 05:30:49 +0000 Subject: [PATCH 30/43] Simplify apply_rotary_emb() etc --- k_diffusion/models/image_transformer_v2.py | 64 +++++++++++----------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index d146af7..8a668c2 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -171,13 +171,14 @@ def forward(self, x, cond): # Rotary position embeddings @compile -def apply_rotary_emb(x, cos, sin, conj=False): +def apply_rotary_emb(x, angle, conj=False): out_dtype = x.dtype - dtype = reduce(torch.promote_types, (x.dtype, cos.dtype, sin.dtype, torch.float32)) - d = cos.shape[-1] + dtype = reduce(torch.promote_types, (x.dtype, angle.dtype, torch.float32)) + d = angle.shape[-1] assert d * 2 <= x.shape[-1] x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] - x1, x2, cos, sin = x1.to(dtype), x2.to(dtype), cos.to(dtype), sin.to(dtype) + x1, x2, angle = x1.to(dtype), x2.to(dtype), angle.to(dtype) + cos, sin = torch.cos(angle), torch.sin(angle) sin = -sin if conj else sin y1 = (x1 * cos - x2 * sin).to(out_dtype) y2 = (x2 * cos + x1 * sin).to(out_dtype) @@ -185,11 +186,12 @@ def apply_rotary_emb(x, cos, sin, conj=False): @compile -def _apply_rotary_emb_inplace(x, cos, sin, conj): - d = cos.shape[-1] +def _apply_rotary_emb_inplace(x, angle, conj): + d = angle.shape[-1] assert d * 2 <= x.shape[-1] x1, x2 = x[..., :d], x[..., d : d * 2] tmp = x1.clone() + cos, sin = torch.cos(angle), torch.sin(angle) x1.mul_(cos).addcmul_(x2, sin, value=1 if conj else -1) x2.mul_(cos).addcmul_(tmp, sin, value=-1 if conj else 1) return x @@ -197,23 +199,23 @@ def _apply_rotary_emb_inplace(x, cos, sin, conj): class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod - def forward(x, cos, sin, conj): - return _apply_rotary_emb_inplace(x, cos, sin, conj=conj) + def forward(x, angle, conj): + return _apply_rotary_emb_inplace(x, angle, conj=conj) @staticmethod def setup_context(ctx, inputs, output): - _, cos, sin, conj = inputs - ctx.save_for_backward(cos, sin) + _, angle, conj = inputs + ctx.save_for_backward(angle) ctx.conj = conj @staticmethod def backward(ctx, grad_output): - cos, sin = ctx.saved_tensors - return _apply_rotary_emb_inplace(grad_output, cos, sin, conj=not ctx.conj), None, None, None + angle, = ctx.saved_tensors + return _apply_rotary_emb_inplace(grad_output, angle, conj=not ctx.conj), None, None -def apply_rotary_emb_(x, cos, sin, conj=False): - return ApplyRotaryEmbeddingInplace.apply(x, cos, sin, conj) +def apply_rotary_emb_(x, angle, conj=False): + return ApplyRotaryEmbeddingInplace.apply(x, angle, conj) class AxialRoPE(nn.Module): @@ -228,10 +230,9 @@ def extra_repr(self): return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" def forward(self, pos): - freqs_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) - freqs_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) - freqs = torch.cat((freqs_h, freqs_w), dim=-1) - return freqs.cos(), freqs.sin() + angle_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) + angle_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) + return torch.cat((angle_h, angle_w), dim=-1) # Shifted window attention @@ -357,21 +358,20 @@ def forward(self, x, pos, cond): x = self.norm(x, cond) qkv = self.qkv_proj(x) pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) - cos, sin = self.pos_emb(pos) + angle = self.pos_emb(pos) if use_flash_2(qkv): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - cos = torch.stack((cos, cos, torch.ones_like(cos)), dim=-3) - sin = torch.stack((sin, sin, torch.zeros_like(sin)), dim=-3) - qkv = apply_rotary_emb_(qkv, cos, sin) + angle = torch.stack((angle, angle, torch.zeros_like(angle)), dim=-3) + qkv = apply_rotary_emb_(qkv, angle) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) - q = apply_rotary_emb_(q, cos, sin) - k = apply_rotary_emb_(k, cos, sin) + angle = angle.movedim(-2, -4) + q = apply_rotary_emb_(q, angle) + k = apply_rotary_emb_(k, angle) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -401,10 +401,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) - cos, sin = self.pos_emb(pos) - cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) - q = apply_rotary_emb_(q, cos, sin) - k = apply_rotary_emb_(k, cos, sin) + angle = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, angle) + k = apply_rotary_emb_(k, angle) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) @@ -439,10 +438,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) - cos, sin = self.pos_emb(pos) - cos, sin = cos.movedim(-2, -4), sin.movedim(-2, -4) - q = apply_rotary_emb_(q, cos, sin) - k = apply_rotary_emb_(k, cos, sin) + angle = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, angle) + k = apply_rotary_emb_(k, angle) x = apply_window_attention(self.window_size, self.window_shift, q, k, v) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) From b1748a8c7957dc1b60e84442a9e0ee820ca1fa9a Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 4 Oct 2023 14:07:15 +0000 Subject: [PATCH 31/43] Optimize apply_rotary_emb() further --- k_diffusion/models/image_transformer_v2.py | 79 ++++++++++++---------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 8a668c2..34f7e00 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -171,51 +171,56 @@ def forward(self, x, cond): # Rotary position embeddings @compile -def apply_rotary_emb(x, angle, conj=False): +def apply_rotary_emb(x, imag, conj=False): out_dtype = x.dtype - dtype = reduce(torch.promote_types, (x.dtype, angle.dtype, torch.float32)) - d = angle.shape[-1] + dtype = reduce(torch.promote_types, (x.dtype, imag.dtype, torch.float32)) + d = imag.shape[-1] assert d * 2 <= x.shape[-1] x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] - x1, x2, angle = x1.to(dtype), x2.to(dtype), angle.to(dtype) - cos, sin = torch.cos(angle), torch.sin(angle) + x1, x2, imag = x1.to(dtype), x2.to(dtype), imag.to(dtype) + cos, sin = torch.cos(imag), torch.sin(imag) sin = -sin if conj else sin - y1 = (x1 * cos - x2 * sin).to(out_dtype) - y2 = (x2 * cos + x1 * sin).to(out_dtype) - return torch.cat((y1, y2, x3), dim=-1).to(out_dtype) + y1 = x1 * cos - x2 * sin + y2 = x2 * cos + x1 * sin + y1, y2 = y1.to(out_dtype), y2.to(out_dtype) + return torch.cat((y1, y2, x3), dim=-1) @compile -def _apply_rotary_emb_inplace(x, angle, conj): - d = angle.shape[-1] +def _apply_rotary_emb_inplace(x, imag, conj): + dtype = reduce(torch.promote_types, (x.dtype, imag.dtype, torch.float32)) + d = imag.shape[-1] assert d * 2 <= x.shape[-1] x1, x2 = x[..., :d], x[..., d : d * 2] - tmp = x1.clone() - cos, sin = torch.cos(angle), torch.sin(angle) - x1.mul_(cos).addcmul_(x2, sin, value=1 if conj else -1) - x2.mul_(cos).addcmul_(tmp, sin, value=-1 if conj else 1) + x1_, x2_, imag = x1.to(dtype), x2.to(dtype), imag.to(dtype) + cos, sin = torch.cos(imag), torch.sin(imag) + sin = -sin if conj else sin + y1 = x1_ * cos - x2_ * sin + y2 = x2_ * cos + x1_ * sin + x1.copy_(y1) + x2.copy_(y2) return x class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod - def forward(x, angle, conj): - return _apply_rotary_emb_inplace(x, angle, conj=conj) + def forward(x, imag, conj): + return _apply_rotary_emb_inplace(x, imag, conj=conj) @staticmethod def setup_context(ctx, inputs, output): - _, angle, conj = inputs - ctx.save_for_backward(angle) + _, imag, conj = inputs + ctx.save_for_backward(imag) ctx.conj = conj @staticmethod def backward(ctx, grad_output): - angle, = ctx.saved_tensors - return _apply_rotary_emb_inplace(grad_output, angle, conj=not ctx.conj), None, None + imag, = ctx.saved_tensors + return _apply_rotary_emb_inplace(grad_output, imag, conj=not ctx.conj), None, None -def apply_rotary_emb_(x, angle, conj=False): - return ApplyRotaryEmbeddingInplace.apply(x, angle, conj) +def apply_rotary_emb_(x, imag, conj=False): + return ApplyRotaryEmbeddingInplace.apply(x, imag, conj) class AxialRoPE(nn.Module): @@ -230,9 +235,9 @@ def extra_repr(self): return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" def forward(self, pos): - angle_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) - angle_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) - return torch.cat((angle_h, angle_w), dim=-1) + imag_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) + imag_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) + return torch.cat((imag_h, imag_w), dim=-1) # Shifted window attention @@ -358,20 +363,20 @@ def forward(self, x, pos, cond): x = self.norm(x, cond) qkv = self.qkv_proj(x) pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) - angle = self.pos_emb(pos) + imag = self.pos_emb(pos) if use_flash_2(qkv): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - angle = torch.stack((angle, angle, torch.zeros_like(angle)), dim=-3) - qkv = apply_rotary_emb_(qkv, angle) + imag = torch.stack((imag, imag, torch.zeros_like(imag)), dim=-3) + qkv = apply_rotary_emb_(qkv, imag) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - angle = angle.movedim(-2, -4) - q = apply_rotary_emb_(q, angle) - k = apply_rotary_emb_(k, angle) + imag = imag.movedim(-2, -4) + q = apply_rotary_emb_(q, imag) + k = apply_rotary_emb_(k, imag) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -401,9 +406,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) - angle = self.pos_emb(pos).movedim(-2, -4) - q = apply_rotary_emb_(q, angle) - k = apply_rotary_emb_(k, angle) + imag = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, imag) + k = apply_rotary_emb_(k, imag) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) @@ -438,9 +443,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) - angle = self.pos_emb(pos).movedim(-2, -4) - q = apply_rotary_emb_(q, angle) - k = apply_rotary_emb_(k, angle) + imag = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, imag) + k = apply_rotary_emb_(k, imag) x = apply_window_attention(self.window_size, self.window_shift, q, k, v) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) From 2cf5e3e3c3d38255827d1ff9f3d6acf1ab1492ac Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 4 Oct 2023 21:52:36 +0000 Subject: [PATCH 32/43] imag -> theta --- k_diffusion/models/image_transformer_v2.py | 66 +++++++++++----------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 34f7e00..230b39e 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -171,14 +171,14 @@ def forward(self, x, cond): # Rotary position embeddings @compile -def apply_rotary_emb(x, imag, conj=False): +def apply_rotary_emb(x, theta, conj=False): out_dtype = x.dtype - dtype = reduce(torch.promote_types, (x.dtype, imag.dtype, torch.float32)) - d = imag.shape[-1] + dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) + d = theta.shape[-1] assert d * 2 <= x.shape[-1] x1, x2, x3 = x[..., :d], x[..., d : d * 2], x[..., d * 2 :] - x1, x2, imag = x1.to(dtype), x2.to(dtype), imag.to(dtype) - cos, sin = torch.cos(imag), torch.sin(imag) + x1, x2, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) + cos, sin = torch.cos(theta), torch.sin(theta) sin = -sin if conj else sin y1 = x1 * cos - x2 * sin y2 = x2 * cos + x1 * sin @@ -187,13 +187,13 @@ def apply_rotary_emb(x, imag, conj=False): @compile -def _apply_rotary_emb_inplace(x, imag, conj): - dtype = reduce(torch.promote_types, (x.dtype, imag.dtype, torch.float32)) - d = imag.shape[-1] +def _apply_rotary_emb_inplace(x, theta, conj): + dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) + d = theta.shape[-1] assert d * 2 <= x.shape[-1] x1, x2 = x[..., :d], x[..., d : d * 2] - x1_, x2_, imag = x1.to(dtype), x2.to(dtype), imag.to(dtype) - cos, sin = torch.cos(imag), torch.sin(imag) + x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype) + cos, sin = torch.cos(theta), torch.sin(theta) sin = -sin if conj else sin y1 = x1_ * cos - x2_ * sin y2 = x2_ * cos + x1_ * sin @@ -204,23 +204,23 @@ def _apply_rotary_emb_inplace(x, imag, conj): class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod - def forward(x, imag, conj): - return _apply_rotary_emb_inplace(x, imag, conj=conj) + def forward(x, theta, conj): + return _apply_rotary_emb_inplace(x, theta, conj=conj) @staticmethod def setup_context(ctx, inputs, output): - _, imag, conj = inputs - ctx.save_for_backward(imag) + _, theta, conj = inputs + ctx.save_for_backward(theta) ctx.conj = conj @staticmethod def backward(ctx, grad_output): - imag, = ctx.saved_tensors - return _apply_rotary_emb_inplace(grad_output, imag, conj=not ctx.conj), None, None + theta, = ctx.saved_tensors + return _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj), None, None -def apply_rotary_emb_(x, imag, conj=False): - return ApplyRotaryEmbeddingInplace.apply(x, imag, conj) +def apply_rotary_emb_(x, theta, conj=False): + return ApplyRotaryEmbeddingInplace.apply(x, theta, conj) class AxialRoPE(nn.Module): @@ -235,9 +235,9 @@ def extra_repr(self): return f"dim={self.freqs.shape[1] * 4}, n_heads={self.freqs.shape[0]}" def forward(self, pos): - imag_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) - imag_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) - return torch.cat((imag_h, imag_w), dim=-1) + theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) + theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) + return torch.cat((theta_h, theta_w), dim=-1) # Shifted window attention @@ -363,20 +363,20 @@ def forward(self, x, pos, cond): x = self.norm(x, cond) qkv = self.qkv_proj(x) pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) - imag = self.pos_emb(pos) + theta = self.pos_emb(pos) if use_flash_2(qkv): qkv = rearrange(qkv, "n h w (t nh e) -> n (h w) t nh e", t=3, e=self.d_head) qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) - imag = torch.stack((imag, imag, torch.zeros_like(imag)), dim=-3) - qkv = apply_rotary_emb_(qkv, imag) + theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3) + qkv = apply_rotary_emb_(qkv, theta) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - imag = imag.movedim(-2, -4) - q = apply_rotary_emb_(q, imag) - k = apply_rotary_emb_(k, imag) + theta = theta.movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) x = F.scaled_dot_product_attention(q, k, v) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -406,9 +406,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) - imag = self.pos_emb(pos).movedim(-2, -4) - q = apply_rotary_emb_(q, imag) - k = apply_rotary_emb_(k, imag) + theta = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) @@ -443,9 +443,9 @@ def forward(self, x, pos, cond): qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) - imag = self.pos_emb(pos).movedim(-2, -4) - q = apply_rotary_emb_(q, imag) - k = apply_rotary_emb_(k, imag) + theta = self.pos_emb(pos).movedim(-2, -4) + q = apply_rotary_emb_(q, theta) + k = apply_rotary_emb_(k, theta) x = apply_window_attention(self.window_size, self.window_shift, q, k, v) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) From 6b745a528554731867842a1b848e9b2700c0eaf7 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Wed, 4 Oct 2023 22:18:16 +0000 Subject: [PATCH 33/43] Require PyTorch 2.1 --- k_diffusion/models/image_transformer_v2.py | 14 +++++++------- requirements.txt | 3 +-- setup.cfg | 5 ++--- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 230b39e..3fa095f 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -308,7 +308,7 @@ def make_shifted_window_masks(n_h_w, n_w_w, w_h, w_w, shift, device=None): return m -def apply_window_attention(window_size, window_shift, q, k, v): +def apply_window_attention(window_size, window_shift, q, k, v, scale=None): # prep windows and masks q_windows = shifted_window(window_size, window_shift, q) k_windows = shifted_window(window_size, window_shift, k) @@ -321,7 +321,7 @@ def apply_window_attention(window_size, window_shift, q, k, v): mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) # do the attention here - qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask) + qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale) # unwindow qkv = torch.reshape(qkv, (b, heads, h, w, wh, ww, d_head)) @@ -373,11 +373,11 @@ def forward(self, x, pos, cond): x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None] * k.shape[-1]**0.5, 1e-6) - theta = theta.movedim(-2, -4) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6) + theta = theta.movedim(-2, -3) q = apply_rotary_emb_(q, theta) k = apply_rotary_emb_(k, theta) - x = F.scaled_dot_product_attention(q, k, v) + x = F.scaled_dot_product_attention(q, k, v, scale=1.0) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) x = self.out_proj(x) @@ -442,11 +442,11 @@ def forward(self, x, pos, cond): x = self.norm(x, cond) qkv = self.qkv_proj(x) q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head) - q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None] * k.shape[-1]**0.5, 1e-6) + q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6) theta = self.pos_emb(pos).movedim(-2, -4) q = apply_rotary_emb_(q, theta) k = apply_rotary_emb_(k, theta) - x = apply_window_attention(self.window_size, self.window_shift, q, k, v) + x = apply_window_attention(self.window_size, self.window_shift, q, k, v, scale=1.0) x = rearrange(x, "n nh h w e -> n h w (nh e)") x = self.dropout(x) x = self.out_proj(x) diff --git a/requirements.txt b/requirements.txt index a68feb8..c62792c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,11 +6,10 @@ einops jsonmerge kornia Pillow -rotary-embedding-torch safetensors scikit-image scipy -torch>=2.0 +torch>=2.1 torchdiffeq torchsde torchvision diff --git a/setup.cfg b/setup.cfg index 7ec1c5f..1e73b1f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = k-diffusion -version = 0.1.0 +version = 0.2.0.dev0 author = Katherine Crowson author_email = crowsonkb@gmail.com url = https://github.com/crowsonkb/k-diffusion @@ -20,11 +20,10 @@ install_requires = jsonmerge kornia Pillow - rotary-embedding-torch safetensors scikit-image scipy - torch >= 2.0 + torch >= 2.1 torchdiffeq torchsde torchvision From 9a6e9f373c7a7644fa9cedd926b3f03da7d7942e Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 5 Oct 2023 01:01:11 +0000 Subject: [PATCH 34/43] Add 8-bit Adam --- train.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/train.py b/train.py index 6bcab1c..e7253f0 100755 --- a/train.py +++ b/train.py @@ -155,6 +155,13 @@ def main(): betas=tuple(opt_config['betas']), eps=opt_config['eps'], weight_decay=opt_config['weight_decay']) + elif opt_config['type'] == 'adam8bit': + import bitsandbytes as bnb + opt = bnb.optim.Adam8bit(groups, + lr=lr, + betas=tuple(opt_config['betas']), + eps=opt_config['eps'], + weight_decay=opt_config['weight_decay']) elif opt_config['type'] == 'sgd': opt = optim.SGD(groups, lr=lr, From de9bb2d1cc9b575c524c7fab6ee0a616870ef44b Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Thu, 5 Oct 2023 16:59:44 +0000 Subject: [PATCH 35/43] Avoid an extra allocation in apply_rotary_emb_() --- k_diffusion/models/image_transformer_v2.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 3fa095f..1b4a142 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -199,13 +199,13 @@ def _apply_rotary_emb_inplace(x, theta, conj): y2 = x2_ * cos + x1_ * sin x1.copy_(y1) x2.copy_(y2) - return x class ApplyRotaryEmbeddingInplace(torch.autograd.Function): @staticmethod def forward(x, theta, conj): - return _apply_rotary_emb_inplace(x, theta, conj=conj) + _apply_rotary_emb_inplace(x, theta, conj=conj) + return x @staticmethod def setup_context(ctx, inputs, output): @@ -216,11 +216,12 @@ def setup_context(ctx, inputs, output): @staticmethod def backward(ctx, grad_output): theta, = ctx.saved_tensors - return _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj), None, None + _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj) + return grad_output, None, None -def apply_rotary_emb_(x, theta, conj=False): - return ApplyRotaryEmbeddingInplace.apply(x, theta, conj) +def apply_rotary_emb_(x, theta): + return ApplyRotaryEmbeddingInplace.apply(x, theta, False) class AxialRoPE(nn.Module): From 134015491a9fb2eb98e770f7d454d75c094fd436 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Fri, 6 Oct 2023 23:30:31 +0000 Subject: [PATCH 36/43] Default to no checkpointing if not in the checkpointing context manager --- k_diffusion/models/flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/k_diffusion/models/flags.py b/k_diffusion/models/flags.py index 8b1e4bd..1b35268 100644 --- a/k_diffusion/models/flags.py +++ b/k_diffusion/models/flags.py @@ -25,4 +25,4 @@ def checkpointing(enable=True): def get_checkpointing(): - return state.checkpointing + return getattr(state, "checkpointing", False) From ccd801fb8a625d8a715476043c10d379b86c16a0 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sun, 8 Oct 2023 16:15:40 +0000 Subject: [PATCH 37/43] Fix wall clock time counter --- k_diffusion/utils.py | 27 --------------------------- train.py | 28 +++++++++++++++++++++------- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/k_diffusion/utils.py b/k_diffusion/utils.py index f14c643..946f2da 100644 --- a/k_diffusion/utils.py +++ b/k_diffusion/utils.py @@ -456,30 +456,3 @@ def ema_update_dict(values, updates, decay): values[k] *= decay values[k] += (1 - decay) * v return values - - -class Timer: - def __init__(self, elapsed=0.0): - """A simple counter of elapsed time.""" - self.elapsed = elapsed - self.last_time = None - - def get(self, time_value=None): - """Updates and returns the elapsed time.""" - time_value = time_value or time.time() - if self.last_time: - self.elapsed += time_value - self.last_time - self.last_time = time_value - return self.elapsed - - def start(self, time_value=None): - """Starts counting elapsed time.""" - time_value = time_value or time.time() - self.get(time_value) - self.last_time = time_value - - def stop(self, time_value=None): - """Stops counting elapsed time.""" - time_value = time_value or time.time() - self.get(time_value) - self.last_time = None diff --git a/train.py b/train.py index e7253f0..0b627aa 100755 --- a/train.py +++ b/train.py @@ -9,6 +9,7 @@ import math import json from pathlib import Path +import time import accelerate import safetensors.torch as safetorch @@ -126,7 +127,7 @@ def main(): seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes], generator=torch.Generator().manual_seed(args.seed)) torch.manual_seed(seeds[accelerator.process_index]) demo_gen = torch.Generator().manual_seed(torch.randint(-2 ** 63, 2 ** 63 - 1, ()).item()) - timer = K.utils.Timer() + elapsed = 0.0 inner_model = K.config.make_model(config) inner_model_ema = deepcopy(inner_model) @@ -233,7 +234,7 @@ def main(): class_key = dataset_config.get('class_key', 1) train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True, - num_workers=args.num_workers, persistent_workers=True) + num_workers=args.num_workers, persistent_workers=True, pin_memory=True) inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) if use_wandb: @@ -274,7 +275,7 @@ def main(): if args.gns and ckpt.get('gns_stats', None) is not None: gns_stats.load_state_dict(ckpt['gns_stats']) demo_gen.set_state(ckpt['demo_gen']) - timer = K.utils.Timer(ckpt.get('elapsed', 0.0)) + elapsed = ckpt.get('elapsed', 0.0) del ckpt else: @@ -375,7 +376,7 @@ def sample_fn(n): kid = K.evaluation.kid(fakes_features, reals_features) print(f'FID: {fid.item():g}, KID: {kid.item():g}') if accelerator.is_main_process: - metrics_log.write(step, timer.get(), ema_stats['loss'], fid.item(), kid.item()) + metrics_log.write(step, elapsed, ema_stats['loss'], fid.item(), kid.item()) if use_wandb: wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) @@ -398,7 +399,7 @@ def save(): 'gns_stats': gns_stats.state_dict() if gns_stats is not None else None, 'ema_stats': ema_stats, 'demo_gen': demo_gen.get_state(), - 'elapsed': timer.get(), + 'elapsed': elapsed, } accelerator.save(obj, filename) if accelerator.is_main_process: @@ -412,7 +413,14 @@ def save(): try: while True: for batch in tqdm(train_dl, smoothing=0.1, disable=not accelerator.is_main_process): - timer.start() + if device.type == 'cuda': + start_timer = torch.cuda.Event(enable_timing=True) + end_timer = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start_timer.record() + else: + start_timer = time.time() + with accelerator.accumulate(model): reals, _, aug_cond = batch[image_key] class_cond, extra_args = None, {} @@ -444,6 +452,13 @@ def save(): K.utils.ema_update(model, model_ema, ema_decay) ema_sched.step() + if device.type == 'cuda': + end_timer.record() + torch.cuda.synchronize() + elapsed += start_timer.elapsed_time(end_timer) / 1000 + else: + elapsed += time.time() - start_timer + if step % 25 == 0: loss_disp = sum(losses_since_last_print) / len(losses_since_last_print) losses_since_last_print.clear() @@ -466,7 +481,6 @@ def save(): wandb.log(log_dict, step=step) step += 1 - timer.stop() if step % args.demo_every == 0: demo() From 8907fea5906b71887a859bffdef05041e4d1b142 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Mon, 9 Oct 2023 18:24:23 +0000 Subject: [PATCH 38/43] Don't spawn a bunch of processes on import --- k_diffusion/models/axial_rope.py | 11 ++------- k_diffusion/models/flags.py | 28 ++++++++++++++++++++++ k_diffusion/models/image_transformer_v1.py | 16 ++++--------- k_diffusion/models/image_transformer_v2.py | 21 +++++----------- 4 files changed, 40 insertions(+), 36 deletions(-) diff --git a/k_diffusion/models/axial_rope.py b/k_diffusion/models/axial_rope.py index ba105cb..cadc3b6 100644 --- a/k_diffusion/models/axial_rope.py +++ b/k_diffusion/models/axial_rope.py @@ -17,7 +17,8 @@ def rotate_half(x): return x.view(*shape, d * r) -def _apply_rotary_emb(freqs, t, start_index=0, scale=1.0): +@flags.compile_wrap +def apply_rotary_emb(freqs, t, start_index=0, scale=1.0): freqs = freqs.to(t) rot_dim = freqs.shape[-1] end_index = start_index + rot_dim @@ -27,14 +28,6 @@ def _apply_rotary_emb(freqs, t, start_index=0, scale=1.0): return torch.cat((t_left, t, t_right), dim=-1) -try: - if not flags.get_use_compile(): - raise RuntimeError - apply_rotary_emb = torch.compile(_apply_rotary_emb) -except RuntimeError: - apply_rotary_emb = _apply_rotary_emb - - def centers(start, stop, num, dtype=None, device=None): edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) return (edges[:-1] + edges[1:]) / 2 diff --git a/k_diffusion/models/flags.py b/k_diffusion/models/flags.py index 1b35268..6d555b7 100644 --- a/k_diffusion/models/flags.py +++ b/k_diffusion/models/flags.py @@ -1,7 +1,10 @@ from contextlib import contextmanager +from functools import update_wrapper import os import threading +import torch + def get_use_compile(): return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1" @@ -26,3 +29,28 @@ def checkpointing(enable=True): def get_checkpointing(): return getattr(state, "checkpointing", False) + + +class compile_wrap: + def __init__(self, function, *args, **kwargs): + self.function = function + self.args = args + self.kwargs = kwargs + self._compiled_function = None + update_wrapper(self, function) + + @property + def compiled_function(self): + if self._compiled_function is not None: + return self._compiled_function + if get_use_compile(): + try: + self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs) + except RuntimeError: + self._compiled_function = self.function + else: + self._compiled_function = self.function + return self._compiled_function + + def __call__(self, *args, **kwargs): + return self.compiled_function(*args, **kwargs) diff --git a/k_diffusion/models/image_transformer_v1.py b/k_diffusion/models/image_transformer_v1.py index 65ad797..c37c639 100644 --- a/k_diffusion/models/image_transformer_v1.py +++ b/k_diffusion/models/image_transformer_v1.py @@ -73,28 +73,20 @@ def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0): return F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=dropout_p) -def _geglu(x): +@flags.compile_wrap +def geglu(x): a, b = x.chunk(2, dim=-1) return a * F.gelu(b) -def _rms_norm(x, scale, eps): +@flags.compile_wrap +def rms_norm(x, scale, eps): dtype = torch.promote_types(x.dtype, torch.float32) mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) return x * scale.to(x.dtype) -try: - if not flags.get_use_compile(): - raise RuntimeError - geglu = torch.compile(_geglu) - rms_norm = torch.compile(_rms_norm) -except RuntimeError: - geglu = _geglu - rms_norm = _rms_norm - - class GEGLU(nn.Module): def forward(self, x): return geglu(x) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 1b4a142..3582f78 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -86,16 +86,7 @@ def filter_params(function, module): # Kernels -def compile(function, *args, **kwargs): - if not flags.get_use_compile(): - return function - try: - return torch.compile(function, *args, **kwargs) - except RuntimeError: - return function - - -@compile +@flags.compile_wrap def linear_geglu(x, weight, bias=None): x = x @ weight.mT if bias is not None: @@ -104,7 +95,7 @@ def linear_geglu(x, weight, bias=None): return x * F.gelu(gate) -@compile +@flags.compile_wrap def rms_norm(x, scale, eps): dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) @@ -112,7 +103,7 @@ def rms_norm(x, scale, eps): return x * scale.to(x.dtype) -@compile +@flags.compile_wrap def scale_for_cosine_sim(q, k, scale, eps): dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32)) sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True) @@ -123,7 +114,7 @@ def scale_for_cosine_sim(q, k, scale, eps): return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype) -@compile +@flags.compile_wrap def scale_for_cosine_sim_qkv(qkv, scale, eps): q, k, v = qkv.unbind(2) q, k = scale_for_cosine_sim(q, k, scale[:, None], eps) @@ -170,7 +161,7 @@ def forward(self, x, cond): # Rotary position embeddings -@compile +@flags.compile_wrap def apply_rotary_emb(x, theta, conj=False): out_dtype = x.dtype dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) @@ -186,7 +177,7 @@ def apply_rotary_emb(x, theta, conj=False): return torch.cat((y1, y2, x3), dim=-1) -@compile +@flags.compile_wrap def _apply_rotary_emb_inplace(x, theta, conj): dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32)) d = theta.shape[-1] From 962d28bf5306dfb4f36145301066702da38f8124 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Mon, 9 Oct 2023 18:38:53 +0000 Subject: [PATCH 39/43] Weight decay token merge and split projections --- k_diffusion/models/image_transformer_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 3582f78..5f56ed7 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -557,7 +557,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = nn.Linear(in_features * self.h * self.w, out_features, bias=False) + self.proj = apply_wd(nn.Linear(in_features * self.h * self.w, out_features, bias=False)) def forward(self, x): x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w) @@ -569,7 +569,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = nn.Linear(in_features, out_features * self.h * self.w, bias=False) + self.proj = apply_wd(nn.Linear(in_features, out_features * self.h * self.w, bias=False)) def forward(self, x): x = self.proj(x) @@ -581,7 +581,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = nn.Linear(in_features, out_features * self.h * self.w, bias=False) + self.proj = apply_wd(nn.Linear(in_features, out_features * self.h * self.w, bias=False)) self.fac = nn.Parameter(torch.ones(1) * 0.5) def forward(self, x, skip): From 11e69027f628128a49b5739e51ac38fb6d21358e Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Mon, 9 Oct 2023 23:19:29 +0000 Subject: [PATCH 40/43] Add --evaluate-only to train.py --- config_from_inference.py | 35 +++++++++++++++++++++++++++++++++++ train.py | 19 ++++++++++++++----- 2 files changed, 49 insertions(+), 5 deletions(-) create mode 100755 config_from_inference.py diff --git a/config_from_inference.py b/config_from_inference.py new file mode 100755 index 0000000..f83b49f --- /dev/null +++ b/config_from_inference.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +"""Extracts the configuration file from a slim inference checkpoint.""" + +import argparse +import json +from pathlib import Path +import sys + +import k_diffusion as K +import safetensors.torch as safetorch + + +def main(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + p.add_argument("checkpoint", type=Path, + help="the inference checkpoint to extract the configuration from") + p.add_argument("--output", "-o", type=Path, + help="the output configuration file") + args = p.parse_args() + + print(f"Loading inference checkpoint {args.checkpoint}...", file=sys.stderr) + metadata = K.utils.get_safetensors_metadata(args.checkpoint) + if "config" not in metadata: + raise ValueError("No configuration found in checkpoint") + + output_path = args.output or args.checkpoint.with_suffix(".json") + + print(f"Saving configuration to {output_path}...", file=sys.stderr) + output_path.write_text(metadata["config"]) + + +if __name__ == "__main__": + main() diff --git a/train.py b/train.py index 0b627aa..6d85de1 100755 --- a/train.py +++ b/train.py @@ -52,12 +52,14 @@ def main(): p.add_argument('--end-step', type=int, default=None, help='the step to end training at') p.add_argument('--evaluate-every', type=int, default=10000, - help='save a demo grid every this many steps') + help='evaluate every this many steps') + p.add_argument('--evaluate-n', type=int, default=2000, + help='the number of samples to draw to evaluate') + p.add_argument('--evaluate-only', action='store_true', + help='evaluate instead of training') p.add_argument('--evaluate-with', type=str, default='inception', choices=['inception', 'clip', 'dinov2'], help='the feature extractor to use for evaluation') - p.add_argument('--evaluate-n', type=int, default=2000, - help='the number of samples to draw to evaluate') p.add_argument('--gns', action='store_true', help='measure the gradient noise scale (DDP only, disables stratified sampling)') p.add_argument('--grad-accum-steps', type=int, default=1, @@ -297,6 +299,7 @@ def main(): del ckpt evaluate_enabled = args.evaluate_every > 0 and args.evaluate_n > 0 + metrics_log = None if evaluate_enabled: if args.evaluate_with == 'inception': extractor = K.evaluation.InceptionV3FeatureExtractor(device=device) @@ -310,7 +313,7 @@ def main(): if accelerator.is_main_process: print('Computing features for reals...') reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[image_key][1], extractor, args.evaluate_n, args.batch_size) - if accelerator.is_main_process: + if accelerator.is_main_process and not args.evaluate_only: metrics_log = K.utils.CSVLogger(f'{args.name}_metrics.csv', ['step', 'time', 'loss', 'fid', 'kid']) del train_iter @@ -375,7 +378,7 @@ def sample_fn(n): fid = K.evaluation.fid(fakes_features, reals_features) kid = K.evaluation.kid(fakes_features, reals_features) print(f'FID: {fid.item():g}, KID: {kid.item():g}') - if accelerator.is_main_process: + if accelerator.is_main_process and metrics_log is not None: metrics_log.write(step, elapsed, ema_stats['loss'], fid.item(), kid.item()) if use_wandb: wandb.log({'FID': fid.item(), 'KID': kid.item()}, step=step) @@ -408,6 +411,12 @@ def save(): if args.wandb_save_model and use_wandb: wandb.save(filename) + if args.evaluate_only: + if not evaluate_enabled: + raise ValueError('--evaluate-only requested but evaluation is disabled') + evaluate() + return + losses_since_last_print = [] try: From f9d5a59426c9d2c64810b05e4395977d9106fd40 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sun, 22 Oct 2023 18:13:08 +0000 Subject: [PATCH 41/43] WIP flop counter --- k_diffusion/models/__init__.py | 1 + k_diffusion/models/image_transformer_v2.py | 44 ++++++++++++++-------- train.py | 13 ++++++- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/k_diffusion/models/__init__.py b/k_diffusion/models/__init__.py index 14cb6d2..2ee2244 100644 --- a/k_diffusion/models/__init__.py +++ b/k_diffusion/models/__init__.py @@ -1,3 +1,4 @@ +from . import flops from .flags import checkpointing, get_checkpointing from .image_v1 import ImageDenoiserModelV1 from .image_transformer_v1 import ImageTransformerDenoiserModelV1 diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 5f56ed7..4cab4f8 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -11,7 +11,7 @@ import torch._dynamo from torch.nn import functional as F -from . import flags +from . import flags, flops from .. import layers from .axial_rope import make_axial_pos @@ -123,12 +123,19 @@ def scale_for_cosine_sim_qkv(qkv, scale, eps): # Layers +class Linear(nn.Linear): + def forward(self, x): + flops.op(flops.op_linear, x.shape, self.weight.shape) + return super().forward(x) + + class LinearGEGLU(nn.Linear): def __init__(self, in_features, out_features, bias=True): super().__init__(in_features, out_features * 2, bias=bias) self.out_features = out_features def forward(self, x): + flops.op(flops.op_linear, x.shape, self.weight.shape) return linear_geglu(x, self.weight, self.bias) @@ -149,7 +156,7 @@ class AdaRMSNorm(nn.Module): def __init__(self, features, cond_features, eps=1e-6): super().__init__() self.eps = eps - self.linear = apply_wd(zero_init(nn.Linear(cond_features, features, bias=False))) + self.linear = apply_wd(zero_init(Linear(cond_features, features, bias=False))) tag_module(self.linear, "mapping") def extra_repr(self): @@ -313,6 +320,7 @@ def apply_window_attention(window_size, window_shift, q, k, v, scale=None): mask = torch.reshape(mask, (h, w, wh * ww, wh * ww)) # do the attention here + flops.op(flops.op_attention, q_seqs.shape, k_seqs.shape, v_seqs.shape) qkv = F.scaled_dot_product_attention(q_seqs, k_seqs, v_seqs, mask, scale=scale) # unwindow @@ -341,11 +349,11 @@ def __init__(self, d_model, d_head, cond_features, dropout=0.0): self.d_head = d_head self.n_heads = d_model // d_head self.norm = AdaRMSNorm(d_model, cond_features) - self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) - self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) def extra_repr(self): return f"d_head={self.d_head}," @@ -361,6 +369,8 @@ def forward(self, x, pos, cond): qkv = scale_for_cosine_sim_qkv(qkv, self.scale, 1e-6) theta = torch.stack((theta, theta, torch.zeros_like(theta)), dim=-3) qkv = apply_rotary_emb_(qkv, theta) + flops_shape = qkv.shape[-5], qkv.shape[-2], qkv.shape[-4], qkv.shape[-1] + flops.op(flops.op_attention, flops_shape, flops_shape, flops_shape) x = flash_attn.flash_attn_qkvpacked_func(qkv, softmax_scale=1.0) x = rearrange(x, "n (h w) nh e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) else: @@ -369,6 +379,7 @@ def forward(self, x, pos, cond): theta = theta.movedim(-2, -3) q = apply_rotary_emb_(q, theta) k = apply_rotary_emb_(k, theta) + flops.op(flops.op_attention, q.shape, k.shape, v.shape) x = F.scaled_dot_product_attention(q, k, v, scale=1.0) x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) x = self.dropout(x) @@ -383,11 +394,11 @@ def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0): self.n_heads = d_model // d_head self.kernel_size = kernel_size self.norm = AdaRMSNorm(d_model, cond_features) - self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) - self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) def extra_repr(self): return f"d_head={self.d_head}, kernel_size={self.kernel_size}" @@ -403,6 +414,7 @@ def forward(self, x, pos, cond): k = apply_rotary_emb_(k, theta) if natten is None: raise ModuleNotFoundError("natten is required for neighborhood attention") + flops.op(flops.op_natten, q.shape, k.shape, v.shape, self.kernel_size) qk = natten.functional.natten2dqk(q, k, self.kernel_size, 1) a = torch.softmax(qk, dim=-1) x = natten.functional.natten2dav(a, v, self.kernel_size, 1) @@ -420,11 +432,11 @@ def __init__(self, d_model, d_head, cond_features, window_size, window_shift, dr self.window_size = window_size self.window_shift = window_shift self.norm = AdaRMSNorm(d_model, cond_features) - self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False)) + self.qkv_proj = apply_wd(Linear(d_model, d_model * 3, bias=False)) self.scale = nn.Parameter(torch.full([self.n_heads], 10.0)) self.pos_emb = AxialRoPE(d_head // 2, self.n_heads) self.dropout = nn.Dropout(dropout) - self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False))) + self.out_proj = apply_wd(zero_init(Linear(d_model, d_model, bias=False))) def extra_repr(self): return f"d_head={self.d_head}, window_size={self.window_size}, window_shift={self.window_shift}" @@ -451,7 +463,7 @@ def __init__(self, d_model, d_ff, cond_features, dropout=0.0): self.norm = AdaRMSNorm(d_model, cond_features) self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) self.dropout = nn.Dropout(dropout) - self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) + self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) def forward(self, x, cond): skip = x @@ -524,7 +536,7 @@ def __init__(self, d_model, d_ff, dropout=0.0): self.norm = RMSNorm(d_model) self.up_proj = apply_wd(LinearGEGLU(d_model, d_ff, bias=False)) self.dropout = nn.Dropout(dropout) - self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False))) + self.down_proj = apply_wd(zero_init(Linear(d_ff, d_model, bias=False))) def forward(self, x): skip = x @@ -557,7 +569,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = apply_wd(nn.Linear(in_features * self.h * self.w, out_features, bias=False)) + self.proj = apply_wd(Linear(in_features * self.h * self.w, out_features, bias=False)) def forward(self, x): x = rearrange(x, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=self.h, nw=self.w) @@ -569,7 +581,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = apply_wd(nn.Linear(in_features, out_features * self.h * self.w, bias=False)) + self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) def forward(self, x): x = self.proj(x) @@ -581,7 +593,7 @@ def __init__(self, in_features, out_features, patch_size=(2, 2)): super().__init__() self.h = patch_size[0] self.w = patch_size[1] - self.proj = apply_wd(nn.Linear(in_features, out_features * self.h * self.w, bias=False)) + self.proj = apply_wd(Linear(in_features, out_features * self.h * self.w, bias=False)) self.fac = nn.Parameter(torch.ones(1) * 0.5) def forward(self, x, skip): @@ -639,11 +651,11 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c self.patch_in = TokenMerge(in_channels, levels[0].width, patch_size) self.time_emb = layers.FourierFeatures(1, mapping.width) - self.time_in_proj = nn.Linear(mapping.width, mapping.width, bias=False) + self.time_in_proj = Linear(mapping.width, mapping.width, bias=False) self.aug_emb = layers.FourierFeatures(9, mapping.width) - self.aug_in_proj = nn.Linear(mapping.width, mapping.width, bias=False) + self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False) self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None - self.mapping_cond_in_proj = nn.Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None + self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=dropout), "mapping") self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() diff --git a/train.py b/train.py index 6d85de1..2c77124 100755 --- a/train.py +++ b/train.py @@ -18,7 +18,7 @@ from torch import distributed as dist from torch import multiprocessing as mp from torch import optim -from torch.utils import data +from torch.utils import data, flop_counter from torchvision import datasets, transforms, utils from tqdm.auto import tqdm @@ -239,6 +239,17 @@ def main(): num_workers=args.num_workers, persistent_workers=True, pin_memory=True) inner_model, inner_model_ema, opt, train_dl = accelerator.prepare(inner_model, inner_model_ema, opt, train_dl) + + with torch.no_grad(), K.models.flops.flop_counter() as fc: + x = torch.zeros([1, model_config['input_channels'], size[0], size[1]], device=device) + sigma = torch.ones([1], device=device) + extra_args = {} + if getattr(unwrap(inner_model), "num_classes", 0): + extra_args['class_cond'] = torch.zeros([1], dtype=torch.long, device=device) + inner_model(x, sigma, **extra_args) + if accelerator.is_main_process: + print(f"Forward pass GFLOPs: {fc.flops / 1_000_000_000:,.3f}", flush=True) + if use_wandb: wandb.watch(inner_model) if accelerator.num_processes == 1: From f4cfe66ea10589062e2157d99654696beac2bf5f Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Fri, 27 Oct 2023 23:43:18 +0000 Subject: [PATCH 42/43] Per level dropout rate --- configs/config_oxford_flowers.json | 3 ++- .../config_oxford_flowers_shifted_window.json | 3 ++- k_diffusion/config.py | 14 ++++++++++---- k_diffusion/models/image_transformer_v2.py | 16 +++++++++------- 4 files changed, 23 insertions(+), 13 deletions(-) diff --git a/configs/config_oxford_flowers.json b/configs/config_oxford_flowers.json index bae25d7..fe3f759 100644 --- a/configs/config_oxford_flowers.json +++ b/configs/config_oxford_flowers.json @@ -13,7 +13,8 @@ ], "loss_config": "karras", "loss_weighting": "soft-min-snr", - "dropout_rate": 0.1, + "dropout_rate": [0.0, 0.0, 0.1], + "mapping_dropout_rate": 0.0, "augment_prob": 0.0, "sigma_data": 0.5, "sigma_min": 1e-2, diff --git a/configs/config_oxford_flowers_shifted_window.json b/configs/config_oxford_flowers_shifted_window.json index c292f9b..1bb2932 100644 --- a/configs/config_oxford_flowers_shifted_window.json +++ b/configs/config_oxford_flowers_shifted_window.json @@ -13,7 +13,8 @@ ], "loss_config": "karras", "loss_weighting": "soft-min-snr", - "dropout_rate": 0.1, + "dropout_rate": [0.0, 0.0, 0.1], + "mapping_dropout_rate": 0.0, "augment_prob": 0.0, "sigma_data": 0.5, "sigma_min": 1e-2, diff --git a/k_diffusion/config.py b/k_diffusion/config.py index bd8e3b7..8ab98b6 100644 --- a/k_diffusion/config.py +++ b/k_diffusion/config.py @@ -61,8 +61,10 @@ def load_config(path_or_dict): 'mapping_depth': 2, 'mapping_d_ff': None, 'mapping_cond_dim': 0, + 'mapping_dropout_rate': 0., 'd_ffs': None, 'self_attns': None, + 'dropout_rate': None, 'augment_wrapper': False, 'skip_stages': 0, 'has_variance': False, @@ -137,6 +139,10 @@ def load_config(path_or_dict): for i in range(len(config['model']['widths'])): self_attns.append(default_neighborhood if i < len(config['model']['widths']) - 1 else default_global) config['model']['self_attns'] = self_attns + if config['model']['dropout_rate'] is None: + config['model']['dropout_rate'] = [0.0] * len(config['model']['widths']) + elif isinstance(config['model']['dropout_rate'], float): + config['model']['dropout_rate'] = [config['model']['dropout_rate']] * len(config['model']['widths']) return merge(defaults, config) @@ -178,8 +184,9 @@ def make_model(config): assert len(config['widths']) == len(config['depths']) assert len(config['widths']) == len(config['d_ffs']) assert len(config['widths']) == len(config['self_attns']) + assert len(config['widths']) == len(config['dropout_rate']) levels = [] - for depth, width, d_ff, self_attn in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns']): + for depth, width, d_ff, self_attn, dropout in zip(config['depths'], config['widths'], config['d_ffs'], config['self_attns'], config['dropout_rate']): if self_attn['type'] == 'global': self_attn = models.image_transformer_v2.GlobalAttentionSpec(self_attn.get('d_head', 64)) elif self_attn['type'] == 'neighborhood': @@ -190,8 +197,8 @@ def make_model(config): self_attn = models.image_transformer_v2.NoAttentionSpec() else: raise ValueError(f'unsupported self attention type {self_attn["type"]}') - levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn)) - mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff']) + levels.append(models.image_transformer_v2.LevelSpec(depth, width, d_ff, self_attn, dropout)) + mapping = models.image_transformer_v2.MappingSpec(config['mapping_depth'], config['mapping_width'], config['mapping_d_ff'], config['mapping_dropout_rate']) model = models.ImageTransformerDenoiserModelV2( levels=levels, mapping=mapping, @@ -200,7 +207,6 @@ def make_model(config): patch_size=config['patch_size'], num_classes=num_classes + 1 if num_classes else 0, mapping_cond_dim=config['mapping_cond_dim'], - dropout=config['dropout_rate'], ) else: raise ValueError(f'unsupported model type {config["type"]}') diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 4cab4f8..f7ac209 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -474,7 +474,7 @@ def forward(self, x, cond): return x + skip -class TransformerLayer(nn.Module): +class GlobalTransformerLayer(nn.Module): def __init__(self, d_model, d_ff, d_head, cond_features, dropout=0.0): super().__init__() self.self_attn = SelfAttentionBlock(d_model, d_head, cond_features, dropout=dropout) @@ -632,6 +632,7 @@ class LevelSpec: width: int d_ff: int self_attn: Union[GlobalAttentionSpec, NeighborhoodAttentionSpec, ShiftedWindowAttentionSpec, NoAttentionSpec] + dropout: float @dataclass @@ -639,12 +640,13 @@ class MappingSpec: depth: int width: int d_ff: int + dropout: float # Model class class ImageTransformerDenoiserModelV2(nn.Module): - def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0, dropout=0.0): + def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_classes=0, mapping_cond_dim=0): super().__init__() self.num_classes = num_classes @@ -656,18 +658,18 @@ def __init__(self, levels, mapping, in_channels, out_channels, patch_size, num_c self.aug_in_proj = Linear(mapping.width, mapping.width, bias=False) self.class_emb = nn.Embedding(num_classes, mapping.width) if num_classes else None self.mapping_cond_in_proj = Linear(mapping_cond_dim, mapping.width, bias=False) if mapping_cond_dim else None - self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=dropout), "mapping") + self.mapping = tag_module(MappingNetwork(mapping.depth, mapping.width, mapping.d_ff, dropout=mapping.dropout), "mapping") self.down_levels, self.up_levels = nn.ModuleList(), nn.ModuleList() for i, spec in enumerate(levels): if isinstance(spec.self_attn, GlobalAttentionSpec): - layer_factory = lambda _: TransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=dropout) + layer_factory = lambda _: GlobalTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, dropout=spec.dropout) elif isinstance(spec.self_attn, NeighborhoodAttentionSpec): - layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=dropout) + layer_factory = lambda _: NeighborhoodTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.kernel_size, dropout=spec.dropout) elif isinstance(spec.self_attn, ShiftedWindowAttentionSpec): - layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=dropout) + layer_factory = lambda i: ShiftedWindowTransformerLayer(spec.width, spec.d_ff, spec.self_attn.d_head, mapping.width, spec.self_attn.window_size, i, dropout=spec.dropout) elif isinstance(spec.self_attn, NoAttentionSpec): - layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=dropout) + layer_factory = lambda _: NoAttentionTransformerLayer(spec.width, spec.d_ff, mapping.width, dropout=spec.dropout) else: raise ValueError(f"unsupported self attention spec {spec.self_attn}") From 9737cfd85120cba1258b5b5b1dc6511356b5c924 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Sat, 28 Oct 2023 14:23:44 +0000 Subject: [PATCH 43/43] Fix flop counter --- k_diffusion/models/flops.py | 54 +++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 k_diffusion/models/flops.py diff --git a/k_diffusion/models/flops.py b/k_diffusion/models/flops.py new file mode 100644 index 0000000..e98e2a8 --- /dev/null +++ b/k_diffusion/models/flops.py @@ -0,0 +1,54 @@ +from contextlib import contextmanager +import math +import threading + + +state = threading.local() +state.flop_counter = None + + +@contextmanager +def flop_counter(enable=True): + try: + old_flop_counter = state.flop_counter + state.flop_counter = FlopCounter() if enable else None + yield state.flop_counter + finally: + state.flop_counter = old_flop_counter + + +class FlopCounter: + def __init__(self): + self.ops = [] + + def op(self, op, *args, **kwargs): + self.ops.append((op, args, kwargs)) + + @property + def flops(self): + flops = 0 + for op, args, kwargs in self.ops: + flops += op(*args, **kwargs) + return flops + + +def op(op, *args, **kwargs): + if getattr(state, "flop_counter", None): + state.flop_counter.op(op, *args, **kwargs) + + +def op_linear(x, weight): + return math.prod(x) * weight[0] + + +def op_attention(q, k, v): + *b, s_q, d_q = q + *b, s_k, d_k = k + *b, s_v, d_v = v + return math.prod(b) * s_q * s_k * (d_q + d_v) + + +def op_natten(q, k, v, kernel_size): + *q_rest, d_q = q + *_, d_v = v + return math.prod(q_rest) * (d_q + d_v) * kernel_size**2