From e450b8e50ca3f97250b676d72b056470b764203d Mon Sep 17 00:00:00 2001
From: Asif Ahmed
Date: Tue, 18 Apr 2023 19:14:26 +0600
Subject: [PATCH] add xformers flag support, improved code

---
 setup.py                     |  2 +-
 vqcompress/compression.py    | 30 +++++++-------
 vqcompress/core/ldm/model.py | 80 ++++++++++--------------------------
 3 files changed, 38 insertions(+), 74 deletions(-)

diff --git a/setup.py b/setup.py
index dfaabf2..661dbcd 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
     name='vqcompress',
     author='Asif Ahmed',
     description='Image compression with vqgan, autoencoder etc.',
-    version='0.1.6',
+    version='0.1.7',
     url='https://github.com/quickgrid/vq-compress',
     packages=find_packages(),
     classifiers=[
diff --git a/vqcompress/compression.py b/vqcompress/compression.py
index f977511..b16bc0c 100644
--- a/vqcompress/compression.py
+++ b/vqcompress/compression.py
@@ -3,7 +3,7 @@
 import os
 import pathlib
 from pathlib import Path
-from typing import Tuple
+from typing import Tuple, List
 
 import numpy as np
 import torch
@@ -16,6 +16,7 @@
 from tqdm import tqdm
 
 from vqcompress.core.ldm.util import instantiate_from_config
+import vqcompress.core.ldm.model
 
 torch.set_grad_enabled(False)
 
@@ -180,24 +181,25 @@ def __init__(
         sd = pl_sd["state_dict"]
         sd_keys = sd.keys()
 
-        def delete_model_layers(layer_initial: str):
-            key_delete_list = []
-            for dkey in sd_keys:
-                if dkey.split('.')[0] == layer_initial:
-                    key_delete_list.append(dkey)
+        def delete_model_layers(layer_initial_list: List[str]):
+            for layer_initial in layer_initial_list:
+                key_delete_list = []
+                for dkey in sd_keys:
+                    if dkey.split('.')[0] == layer_initial:
+                        key_delete_list.append(dkey)
 
-            for k in key_delete_list:
-                del sd[f'{k}']
+                for k in key_delete_list:
+                    del sd[f'{k}']
 
-        for i in ['loss', 'model_ema']:
-            delete_model_layers(i)
+        delete_model_layers(['loss', 'model_ema'])
 
         if use_decompress:
-            for i in ['quant_conv', 'encoder']:
-                delete_model_layers(i)
+            delete_model_layers(['quant_conv', 'encoder'])
         else:
-            for i in ['post_quant_conv', 'decoder']:
-                delete_model_layers(i)
+            delete_model_layers(['post_quant_conv', 'decoder'])
+
+        if use_xformers:
+            vqcompress.core.ldm.model.AttnBlock.forward = vqcompress.core.ldm.model.patch_xformers_attn_forward
 
         # print(sd.keys())
         self.ldm_model = instantiate_from_config(config.model)
diff --git a/vqcompress/core/ldm/model.py b/vqcompress/core/ldm/model.py
index 5ba1597..6d6e86e 100644
--- a/vqcompress/core/ldm/model.py
+++ b/vqcompress/core/ldm/model.py
@@ -151,65 +151,29 @@ def get_xformers_flash_attention_op(q, k, v):
     return None
 
 
-class AttnBlockXformers(nn.Module):
+def patch_xformers_attn_forward(self, x):
     """Copied from, https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_hijack_optimizations.py.
 
     """
-
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-        b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
-        # dtype = q.dtype
-        # if True:
-        #     q, k = q.float(), k.float()
-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        # out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
-        out = xformers.ops.memory_efficient_attention(q, k, v)
-        # out = out.to(dtype)
-        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
-        out = self.proj_out(out)
-        return x + out
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    # dtype = q.dtype
+    # if True:
+    #     q, k = q.float(), k.float()
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    # out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+    out = xformers.ops.memory_efficient_attention(q, k, v)
+    # out = out.to(dtype)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out
 
 
 class AttnBlock(nn.Module):
@@ -275,12 +239,10 @@ def forward(self, x):
         return x+h_
 
 
 def make_attn(in_channels, attn_type="vanilla"):
-    assert attn_type in ["vanilla", "xformers", "none"], f'attn_type {attn_type} unknown'
+    assert attn_type in ["vanilla", "none"], f'attn_type {attn_type} unknown'
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         return AttnBlock(in_channels)
-    if attn_type == "xformers":
-        return AttnBlockXformers(in_channels)
     elif attn_type == "none":
         return nn.Identity(in_channels)