Support TE-DPA For Stable Diffusion (#10314)

* [SD] Add te-dpa support

Signed-off-by: Wil Kong <[email protected]>

* [SD] Add te-dpa support, resolve compatibility with TE-master

Signed-off-by: Wil Kong <[email protected]>

* [SD] Add te-dpa support, add check for attention configs.

Signed-off-by: Wil Kong <[email protected]>

* Fix bugs of flash-attn and dpa in SD.

Signed-off-by: Wil Kong <[email protected]>

* Fix the issue of DPA API change.

Signed-off-by: Wil Kong <[email protected]>

* Apply isort and black reformatting

Signed-off-by: alpha0422 <[email protected]>
Signed-off-by: Wil Kong <[email protected]>

* [SD] TE-DPA: disable te-dpa in inference flow.

---------

Signed-off-by: Wil Kong <[email protected]>
Signed-off-by: alpha0422 <[email protected]>
Co-authored-by: Mengdi Wang <[email protected]>
alpha0422 and Mengdi Wang authored Sep 16, 2024
1 parent 1b0f3af commit 4068955
Showing 3 changed files with 58 additions and 8 deletions.
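
For orientation (not part of the commit): a minimal sketch of how the new flag might be toggled on the UNet config that the loader below inspects. The nesting under unet_config and the OmegaConf usage are assumptions based on the cfg.get('unet_config') calls in the diff; the two backends are mutually exclusive, so flash attention is switched off when TE DPA is enabled.

from omegaconf import OmegaConf

# Hypothetical training-time override; only the two keys are taken from the diff below.
cfg = OmegaConf.create({"unet_config": {"use_te_dpa": True, "use_flash_attention": False}})
assert not (cfg.unet_config.use_te_dpa and cfg.unet_config.use_flash_attention)
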
@@ -2209,6 +2209,9 @@ def load_from_checkpoint(
cfg.channels_last = True
if not cfg.get('capture_cudagraph_iters'):
cfg.capture_cudagraph_iters = -1
if cfg.get('unet_config') and cfg.get('unet_config').get('use_te_dpa'):
cfg.unet_config.use_te_dpa = False
cfg.unet_config.use_flash_attention = True

# compatibility for stable diffusion old checkpoint tweaks
first_key = list(checkpoint['state_dict'].keys())[0]
@@ -2242,6 +2245,14 @@ def load_from_checkpoint(
new_state_dict[new_key] = checkpoint['state_dict'][key]
checkpoint['state_dict'] = new_state_dict

# compatibility for te-dpa in inference
if cfg.get('unet_config') and not cfg.get('unet_config').get('use_te_dpa'):
new_state_dict = {}
for key in checkpoint['state_dict'].keys():
if "_extra_state" not in key:
new_state_dict[key] = checkpoint['state_dict'][key]
checkpoint['state_dict'] = new_state_dict

if cfg.get('megatron_amp_O2', False):
new_state_dict = {}
for key in checkpoint['state_dict'].keys():
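The two hunks above make a TE-DPA-trained checkpoint loadable for inference: the config is flipped back to flash attention, and the "_extra_state" entries that TransformerEngine modules serialize into the state dict are dropped so they do not clash with the rebuilt non-TE attention modules. A hedged, standalone restatement of that logic (the helper name is hypothetical; cfg is assumed to be an OmegaConf object and checkpoint a plain dict):

def make_checkpoint_te_dpa_free(cfg, checkpoint):
    # Inference path: fall back from TE DPA to flash attention (mirrors the first hunk).
    if cfg.get('unet_config') and cfg.get('unet_config').get('use_te_dpa'):
        cfg.unet_config.use_te_dpa = False
        cfg.unet_config.use_flash_attention = True
    # Drop TransformerEngine's serialized "_extra_state" entries (mirrors the second hunk).
    if cfg.get('unet_config') and not cfg.get('unet_config').get('use_te_dpa'):
        checkpoint['state_dict'] = {
            k: v for k, v in checkpoint['state_dict'].items() if "_extra_state" not in k
        }
    return cfg, checkpoint
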
46 changes: 40 additions & 6 deletions nemo/collections/multimodal/modules/stable_diffusion/attention.py
@@ -41,6 +41,7 @@
from nemo.utils import logging

try:
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.module import LayerNormLinear, LayerNormMLP

HAVE_TE = True
@@ -255,11 +256,21 @@ def __init__(
dim_head=64,
dropout=0.0,
use_flash_attention=False,
use_te_dpa=False,
lora_network_alpha=None,
use_te=False,
):
super().__init__()

assert not (
use_te_dpa and use_flash_attention
), 'use_te_dpa and use_flash_attention cannot be True together. Please specify the attention you want to use.'

if use_flash_attention:
assert flash_attn_installed, 'Flash-attention must be installed.'
if use_te_dpa:
assert HAVE_TE, 'TransformerEngine is required to run with TE DPA.'

self.inner_dim = dim_head * heads
if context_dim is None:
self.is_self_attn = True
@@ -277,6 +288,7 @@ def __init__(
self.to_k = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)
self.to_v = LinearWrapper(context_dim, self.inner_dim, bias=False, lora_network_alpha=lora_network_alpha)

self.use_te_dpa = use_te_dpa
self.use_te = use_te
if use_te:
return_layernorm_output = True if self.is_self_attn else False
@@ -292,11 +304,21 @@ def __init__(
)
self.use_flash_attention = use_flash_attention

if dim_head <= 160 and (dim_head % 8) == 0 and flash_attn_installed:
if context_dim == query_dim:
self.flash_attn = FlashSelfAttention(softmax_scale=self.scale)
else:
self.flash_attn = FlashCrossAttention(softmax_scale=self.scale)
if dim_head <= 160 and (dim_head % 8) == 0:
if self.use_flash_attention:
if context_dim == query_dim:
self.flash_attn = FlashSelfAttention(softmax_scale=self.scale)
else:
self.flash_attn = FlashCrossAttention(softmax_scale=self.scale)
elif self.use_te_dpa:
self.te_dpa = DotProductAttention(
kv_channels=dim_head,
num_attention_heads=self.inner_dim // dim_head,
attn_mask_type='no_mask',
attention_type='self' if context_dim == query_dim else 'cross',
qkv_format='bshd', # `sbhd`, `bshd`, `thd`
softmax_scale=self.scale,
)

def forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
h = self.heads
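
For reference, a self-contained sketch of constructing and calling TE's DotProductAttention with the same keyword arguments as above. Sizes, dtype, and device are illustrative assumptions, not taken from the commit; with qkv_format='bshd' the inputs are (batch, seq, heads, head_dim) tensors and the context is returned with the heads folded back into the last dimension, which is what the forward path below relies on.

import torch
from transformer_engine.pytorch.attention import DotProductAttention

b, s_q, s_kv, h, d = 2, 77, 77, 8, 64   # illustrative sizes: 8 heads of 64 channels each
dpa = DotProductAttention(
    kv_channels=d,
    num_attention_heads=h,
    attn_mask_type='no_mask',
    attention_type='self',
    qkv_format='bshd',                  # batch, sequence, heads, head_dim
    softmax_scale=d ** -0.5,
)
q = torch.randn(b, s_q, h, d, dtype=torch.float16, device='cuda')
k = torch.randn(b, s_kv, h, d, dtype=torch.float16, device='cuda')
v = torch.randn(b, s_kv, h, d, dtype=torch.float16, device='cuda')
out = dpa(q, k, v)                      # (b, s_q, h * d), same layout as the flash path output
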
@@ -338,7 +360,7 @@ def _attention(self, q, k, v, mask=None, additional_tokens=None):

if (
not flash_attn_installed
or not self.use_flash_attention
or (not self.use_flash_attention and not self.use_te_dpa)
or q.dtype == torch.float32
or (self.dim_head > 160 or (self.dim_head % 8) != 0)
or mask is not None
@@ -365,6 +387,13 @@ def _attention(self, q, k, v, mask=None, additional_tokens=None):

# (b h) n d -> b n (h d)
out = rearrange_heads_inner(out, h)

elif self.use_te_dpa:
b, s_kv, hd = k.shape
s_q = q.shape[1]
d = hd // h
out = self.te_dpa(q.view(b, s_q, h, d), k.view(b, s_kv, h, d), v.view(b, s_kv, h, d))

elif self.context_dim == self.query_dim:
# self-attention
qkv = torch.stack([q, k, v], dim=2)
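
With these changes, _attention dispatches to TE DPA only when use_te_dpa is set and the fused-kernel preconditions in the guard above hold: flash-attn installed, non-fp32 inputs, head dim at most 160 and divisible by 8, and no explicit mask; otherwise it keeps the existing flash or plain softmax(QK^T)V paths. A hedged restatement of the visible branch selection (the helper and its name are hypothetical; attribute names follow the class above):

def _select_attention_path(self, q, mask):
    fused_ok = (
        flash_attn_installed                          # retained from the original guard
        and (self.use_flash_attention or self.use_te_dpa)
        and q.dtype != torch.float32
        and self.dim_head <= 160
        and (self.dim_head % 8) == 0
        and mask is None
    )
    if not fused_ok:
        return 'math'                                 # plain softmax(QK^T)V fallback
    return 'te_dpa' if self.use_te_dpa else 'flash'
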
@@ -404,6 +433,7 @@ def __init__(
gated_ff=True,
use_checkpoint=False,
use_flash_attention=False,
use_te_dpa=False,
disable_self_attn=False,
lora_network_alpha=None,
use_te=False,
@@ -416,6 +446,7 @@ def __init__(
dim_head=d_head,
dropout=dropout,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
context_dim=context_dim if self.disable_self_attn else None,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
@@ -428,6 +459,7 @@ def __init__(
dim_head=d_head,
dropout=dropout,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
) # is self-attn if context is none
@@ -485,6 +517,7 @@ def __init__(
use_linear=False,
use_checkpoint=False,
use_flash_attention=False,
use_te_dpa=False,
lora_network_alpha=None,
use_te=False,
):
@@ -527,6 +560,7 @@ def __init__(
context_dim=context_dim[d],
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
disable_self_attn=disable_self_attn,
lora_network_alpha=lora_network_alpha,
use_te=use_te,
@@ -609,6 +609,7 @@ def __init__(
from_NeMo=False,
# It must be specified when from pretrained is not None. It indicates loading unet from NeMo trained ckpt or HF
use_flash_attention: bool = False,
use_te_dpa: bool = False,
unet_precision: str = "fp32",
lora_network_alpha=None,
timesteps=1000,
@@ -782,6 +783,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
lora_network_alpha=lora_network_alpha,
use_te=self.use_te_fp8,
)
@@ -851,6 +853,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
use_te=self.use_te_fp8,
lora_network_alpha=lora_network_alpha,
)
@@ -918,6 +921,7 @@ def __init__(
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
use_flash_attention=use_flash_attention,
use_te_dpa=use_te_dpa,
lora_network_alpha=lora_network_alpha,
use_te=self.use_te_fp8,
)
@@ -978,8 +982,8 @@ def __init__(
self.convert_to_fp16()
elif unet_precision == 'fp16':
self.convert_to_fp16(enable_norm_layers=True)
elif self.use_te_fp8:
assert unet_precision != 'fp16', "fp8 training can't work with fp16 O2 amp recipe"
if self.use_te_fp8:
assert unet_precision == 'fp16', "fp8 training can't work with fp16 O2 amp recipe"
convert_module_to_fp8(self)

fp8_margin = int(os.getenv("FP8_MARGIN", '0'))
Expand All @@ -1002,6 +1006,7 @@ def __init__(
amax_history_len=fp8_amax_history_len,
amax_compute_algo=fp8_amax_compute_algo,
override_linear_precision=(False, False, not fp8_wgrad),
# fp8_dpa=use_te_dpa, # TODO; fp8 DPA kernel is not supported now.
)
old_state_dict = self.state_dict()
new_state_dict = self.te_fp8_key_mapping(old_state_dict)