add xformers flag support, improved code
Asif Ahmed committed Apr 18, 2023
1 parent dd55c53 commit e450b8e
Showing 3 changed files with 38 additions and 74 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
     name='vqcompress',
     author='Asif Ahmed',
     description='Image compression with vqgan, autoencoder etc.',
-    version='0.1.6',
+    version='0.1.7',
     url='https://github.com/quickgrid/vq-compress',
     packages=find_packages(),
     classifiers=[

30 changes: 16 additions & 14 deletions vqcompress/compression.py
@@ -3,7 +3,7 @@
 import os
 import pathlib
 from pathlib import Path
-from typing import Tuple
+from typing import Tuple, List

 import numpy as np
 import torch
@@ -16,6 +16,7 @@
 from tqdm import tqdm

 from vqcompress.core.ldm.util import instantiate_from_config
+import vqcompress.core.ldm.model

 torch.set_grad_enabled(False)

@@ -180,24 +181,25 @@ def __init__(
         sd = pl_sd["state_dict"]
         sd_keys = sd.keys()

-        def delete_model_layers(layer_initial: str):
-            key_delete_list = []
-            for dkey in sd_keys:
-                if dkey.split('.')[0] == layer_initial:
-                    key_delete_list.append(dkey)
+        def delete_model_layers(layer_initial_list: List[str]):
+            for layer_initial in layer_initial_list:
+                key_delete_list = []
+                for dkey in sd_keys:
+                    if dkey.split('.')[0] == layer_initial:
+                        key_delete_list.append(dkey)

-            for k in key_delete_list:
-                del sd[f'{k}']
+                for k in key_delete_list:
+                    del sd[f'{k}']

-        for i in ['loss', 'model_ema']:
-            delete_model_layers(i)
+        delete_model_layers(['loss', 'model_ema'])

         if use_decompress:
-            for i in ['quant_conv', 'encoder']:
-                delete_model_layers(i)
+            delete_model_layers(['quant_conv', 'encoder'])
         else:
-            for i in ['post_quant_conv', 'decoder']:
-                delete_model_layers(i)
+            delete_model_layers(['post_quant_conv', 'decoder'])

+        if use_xformers:
+            vqcompress.core.ldm.model.AttnBlock.forward = vqcompress.core.ldm.model.patch_xformers_attn_forward
+
         # print(sd.keys())
         self.ldm_model = instantiate_from_config(config.model)

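In short, the key-pruning helper now accepts a list of layer prefixes in one call instead of being invoked once per prefix, and the xformers path is enabled by rebinding AttnBlock.forward when the use_xformers flag is set. Below is a minimal standalone sketch of the pruning behaviour; the real helper is a closure over sd inside __init__, and the checkpoint keys here are made up for illustration only.

```python
from typing import Dict, List


def delete_model_layers(sd: Dict[str, object], layer_initial_list: List[str]) -> None:
    """Drop every state_dict entry whose top-level key prefix is listed."""
    for layer_initial in layer_initial_list:
        # Collect matching keys first, then delete, so the dict is not
        # mutated while it is being iterated.
        key_delete_list = [k for k in list(sd.keys()) if k.split('.')[0] == layer_initial]
        for k in key_delete_list:
            del sd[k]


# Hypothetical checkpoint keys, only to show which entries survive.
sd = {
    'loss.kl_weight': 0.1,
    'model_ema.decay': 0.999,
    'encoder.conv_in.weight': None,
    'quant_conv.weight': None,
    'decoder.conv_out.weight': None,
    'post_quant_conv.weight': None,
}
delete_model_layers(sd, ['loss', 'model_ema'])
delete_model_layers(sd, ['quant_conv', 'encoder'])  # decompress path: keep only the decoder side
print(sorted(sd))  # ['decoder.conv_out.weight', 'post_quant_conv.weight']
```
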
80 changes: 21 additions & 59 deletions vqcompress/core/ldm/model.py
@@ -151,65 +151,29 @@ def get_xformers_flash_attention_op(q, k, v):
     return None


-class AttnBlockXformers(nn.Module):
+def patch_xformers_attn_forward(self, x):
     """Copied from,
     https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_hijack_optimizations.py.
     """
-
-    def __init__(self, in_channels):
-        super().__init__()
-        self.in_channels = in_channels
-
-        self.norm = Normalize(in_channels)
-        self.q = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.k = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.v = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-        self.proj_out = torch.nn.Conv2d(
-            in_channels,
-            in_channels,
-            kernel_size=1,
-            stride=1,
-            padding=0
-        )
-
-    def forward(self, x):
-        h_ = x
-        h_ = self.norm(h_)
-        q = self.q(h_)
-        k = self.k(h_)
-        v = self.v(h_)
-        b, c, h, w = q.shape
-        q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
-        # dtype = q.dtype
-        # if True:
-        # q, k = q.float(), k.float()
-        q = q.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        # out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
-        out = xformers.ops.memory_efficient_attention(q, k, v)
-        # out = out.to(dtype)
-        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
-        out = self.proj_out(out)
-        return x + out
+    h_ = x
+    h_ = self.norm(h_)
+    q = self.q(h_)
+    k = self.k(h_)
+    v = self.v(h_)
+    b, c, h, w = q.shape
+    q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+    # dtype = q.dtype
+    # if True:
+    # q, k = q.float(), k.float()
+    q = q.contiguous()
+    k = k.contiguous()
+    v = v.contiguous()
+    # out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+    out = xformers.ops.memory_efficient_attention(q, k, v)
+    # out = out.to(dtype)
+    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
+    out = self.proj_out(out)
+    return x + out


 class AttnBlock(nn.Module):
@@ -275,12 +239,10 @@ def forward(self, x):


 def make_attn(in_channels, attn_type="vanilla"):
-    assert attn_type in ["vanilla", "xformers", "none"], f'attn_type {attn_type} unknown'
+    assert attn_type in ["vanilla", "none"], f'attn_type {attn_type} unknown'
     print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         return AttnBlock(in_channels)
-    if attn_type == "xformers":
-        return AttnBlockXformers(in_channels)
     elif attn_type == "none":
         return nn.Identity(in_channels)

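The net effect in model.py is that AttnBlockXformers is removed, make_attn only knows the "vanilla" and "none" attention types, and the memory-efficient path is applied at runtime by rebinding AttnBlock.forward to patch_xformers_attn_forward. A toy sketch of that rebinding pattern follows; the class and function bodies are stand-ins, not the repository's code.

```python
# An unbound function that takes `self` can replace a method on the class,
# so every existing and future instance picks up the new behaviour without
# changing __init__ or the saved checkpoint layout.
class AttnBlock:
    def forward(self, x):
        return x + 1  # stand-in for the vanilla attention path


def patched_forward(self, x):
    return x + 100  # stand-in for the xformers memory-efficient path


use_xformers = True  # would come from the new flag
if use_xformers:
    AttnBlock.forward = patched_forward

print(AttnBlock().forward(1))  # 101 when the flag is set, 2 otherwise
```

As far as I know, xformers.ops.memory_efficient_attention consumes (batch, sequence, channels) tensors, which is why the patched forward rearranges q, k and v from 'b c h w' to 'b (h w) c' and makes them contiguous before the call.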
