-
Notifications
You must be signed in to change notification settings - Fork 232
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dallemini module v1 #416
Open
hexunlin
wants to merge
13
commits into
open-mmlab:dalle
Choose a base branch
from
hexunlin:dalle
base: dalle
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
dallemini module v1 #416
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
7396ca3
dallemini_modules_v1
hexunlin 089e022
Update __init__.py
hexunlin a16319f
vqgan_modules_v1
hexunlin 9066ac5
Update modules.py
hexunlin 022d187
Update __init__.py
hexunlin cd403c4
update format
hexunlin af256a3
update_format2
hexunlin 619879d
update_format3
hexunlin c689d7b
update_format4
hexunlin c9b02bb
update_format5
hexunlin d35b897
update_format6
hexunlin 22ecc78
fixed_format_bugs
hexunlin 7a99c9b
vqvae_quantizer
hexunlin File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .modules import DecoderLayer, EncoderLayer | ||
|
||
__all__ = ['BartDecoderLayer', 'BartEncoderLayer'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,233 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
import torch.nn as nn | ||
from mmcv.cnn.bricks import Linear, build_activation_layer, build_norm_layer | ||
from mmgen.registry import MODULES | ||
|
||
|
||
@MODULES.register_module() | ||
class GLU(nn.Module): | ||
"""GLU variants used to improve Transformer. | ||
|
||
Args: | ||
in_out_channels (int): The channel number of the input | ||
and the output feature map. | ||
mid_channels (int): The channel number of the middle layer feature map. | ||
""" | ||
|
||
def __init__(self, in_out_channels, mid_channels): | ||
super().__init__() | ||
_, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels) | ||
_, self.norm2 = build_norm_layer(dict(type='LN'), mid_channels) | ||
self.fc1 = Linear(in_out_channels, mid_channels, bias=False) | ||
self.fc2 = Linear(in_out_channels, mid_channels, bias=False) | ||
self.fc3 = Linear(mid_channels, in_out_channels, bias=False) | ||
self.gelu = build_activation_layer(dict(type='GELU')) | ||
|
||
def forward(self, z): | ||
"""Forward function. | ||
|
||
Args: | ||
z (torch.FloatTensor): Input feature map. | ||
|
||
Returns: | ||
z (torch.FloatTensor): Output feature map. | ||
""" | ||
z = self.norm1(z) | ||
w = self.fc1(z) | ||
w = self.gelu(w) | ||
v = self.fc2(z) | ||
z = self.norm2(w * v) | ||
z = self.fc3(z) | ||
return z | ||
|
||
|
||
@MODULES.register_module() | ||
class AttentionBase(nn.Module): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. May this module meets our needs? |
||
"""An Muti-head Attention block used in Bart model. | ||
|
||
Ref: | ||
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models | ||
|
||
Args: | ||
in_channels (int): The channel number of the input feature map. | ||
num_heads (int): Number of heads in the attention. | ||
""" | ||
|
||
def __init__(self, in_channels, num_heads): | ||
super().__init__() | ||
self.in_channels = in_channels | ||
self.num_heads = num_heads | ||
self.querie = Linear(in_channels, in_channels, bias=False) | ||
self.key = Linear(in_channels, in_channels, bias=False) | ||
self.value = Linear(in_channels, in_channels, bias=False) | ||
self.proj = Linear(in_channels, in_channels, bias=False) | ||
|
||
def qkv(self, x): | ||
"""Calculate queries, keys and values for the embedding map. | ||
|
||
Args: | ||
x (torch.FloatTensor): Input feature map. | ||
|
||
Returns: | ||
q (torch.FloatTensor): Querie feature map. | ||
k (torch.FloatTensor): Key feature map. | ||
v (torch.FloatTensor): Value feature map. | ||
""" | ||
q = self.querie(x) | ||
k = self.key(x) | ||
v = self.value(x) | ||
|
||
return q, k, v | ||
|
||
def forward(self, q, k, v, attention_mask): | ||
"""Forward function for attention. | ||
|
||
Args: | ||
q (torch.FloatTensor): Querie feature map. | ||
k (torch.FloatTensor): Key feature map. | ||
v (torch.FloatTensor): Value feature map. | ||
attention_mask (torch.BoolTensor): whether to use | ||
an attention mask. | ||
|
||
Returns: | ||
weights (torch.FloatTensor): Feature map after attention. | ||
""" | ||
q = q.reshape(q.shape[:2] + (self.num_heads, -1)) | ||
q /= q.shape[-1]**0.5 | ||
k = k.reshape(k.shape[:2] + (self.num_heads, -1)) | ||
v = v.reshape(v.shape[:2] + (self.num_heads, -1)) | ||
|
||
attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12 | ||
weights = torch.einsum('bqhc,bkhc->bhqk', q, k) | ||
weights += attention_bias | ||
weights = torch.softmax(weights, -1) | ||
weights = torch.einsum('bhqk,bkhc->bqhc', weights, v) | ||
shape = weights.shape[:2] + (self.in_channels, ) | ||
weights = weights.reshape(shape) | ||
weights = self.proj(weights) | ||
return weights | ||
|
||
|
||
@MODULES.register_module() | ||
class BartEncoderLayer(nn.Module): | ||
# yapf: disable | ||
"""EncoderLayer of the Bart model. | ||
|
||
Ref: | ||
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models | ||
|
||
Args: | ||
in_channels (int): The channel number of the input feature map. | ||
head_num (int): Number of heads in the attention. | ||
out_channels (int): The channel number of the output feature map. | ||
""" | ||
|
||
def __init__(self, in_channels, head_num, out_channels): | ||
super().__init__() | ||
self.attn = AttentionBase(in_channels, head_num) | ||
_, self.norm = build_norm_layer(dict(type='LN'), in_channels) | ||
self.glu = GLU(in_channels, out_channels) | ||
|
||
def forward(self, x, attention_mask): | ||
"""Forward function for the encoder layer. | ||
|
||
Args: | ||
x (torch.FloatTensor): Input feature map. | ||
attention_mask (torch.BoolTensor): Whether to use | ||
an attention mask. | ||
|
||
Returns: | ||
x (torch.FloatTensor): Output feature map. | ||
""" | ||
|
||
h = self.norm(x) | ||
q, k, v = self.attn.qkv(h) | ||
h = self.attn(q, k, v, attention_mask) | ||
h = self.norm(h) | ||
x = x + h | ||
h = self.glu(x) | ||
x = x + h | ||
return x | ||
|
||
|
||
@MODULES.register_module() | ||
class BartDecoderLayer(nn.Module): | ||
# yapf: disable | ||
"""DecoderLayer of the Bart model. | ||
|
||
Ref: | ||
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models | ||
|
||
Args: | ||
in_channels (int): The channel number of the input feature map. | ||
head_num (int): Number of heads in the attention. | ||
out_channels (int): The channel number of the output feature map. | ||
token_length (int): The length of tokens. | ||
""" | ||
|
||
def __init__(self, in_channels, head_num, out_channels, token_length=256): | ||
super().__init__() | ||
self.attn = AttentionBase(in_channels, head_num) | ||
self.cross_attn = AttentionBase(in_channels, head_num) | ||
_, self.norm = build_norm_layer(dict(type='LN'), in_channels) | ||
self.glu = GLU(in_channels, out_channels) | ||
self.token_indices = torch.arange(token_length) | ||
|
||
def forward(self, x, encoder_state, attention_state, | ||
attention_mask, token_index): | ||
"""Forward function for the decoder layer. | ||
|
||
Args: | ||
x (torch.FloatTensor): Input feature map of | ||
the decoder embeddings. | ||
encoder_state (torch.FloatTensor): Input feature map of | ||
the encoder embeddings. | ||
attention_state (torch.FloatTensor): Input feature map of | ||
the attention. | ||
attention_mask (torch.BoolTensor): whether to use | ||
an attention mask. | ||
token_index (torch.LongTensor): The index of tokens. | ||
|
||
Returns: | ||
x (torch.FloatTensor): Output feature map of | ||
the decoder embeddings. | ||
attention_state (torch.FloatTensor): Output feature map of | ||
the attention. | ||
""" | ||
|
||
# Self Attention | ||
token_count = token_index.shape[1] | ||
if token_count == 1: | ||
self_attn_mask = self.token_indices <= token_index | ||
self_attn_mask = self_attn_mask[:, None, None, :] | ||
else: | ||
self_attn_mask = (self.token_indices[None, None, :token_count] <= | ||
token_index[:, :, None]) | ||
self_attn_mask = self_attn_mask[:, None, :, :] | ||
|
||
h = self.norm(x) | ||
q, k, v = self.attn.qkv(h) | ||
token_count = token_index.shape[1] | ||
if token_count == 1: | ||
batch_count = h.shape[0] | ||
attn_state_new = torch.cat([k, v]).to(attention_state.dtype) | ||
attention_state[:, token_index[0]] = attn_state_new | ||
k = attention_state[:batch_count] | ||
v = attention_state[batch_count:] | ||
h = self.attn(q, k, v, self_attn_mask) | ||
h = self.norm(h) | ||
x = x + h | ||
|
||
# Cross Attention | ||
h = self.norm(x) | ||
q, _, _ = self.cross_attn.qkv(h) | ||
_, k, v = self.cross_attn.qkv(h) | ||
h = self.cross_attn(q, k, v, attention_mask) | ||
h = self.norm(h) | ||
x = x + h | ||
|
||
h = self.glu(x) | ||
x = x + h | ||
|
||
return x, attention_state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .modules import DiffusionResnetBlock | ||
|
||
__all__ = ['DiffusionResnetBlock'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from mmcv.cnn.bricks import Linear, build_conv_layer, build_norm_layer | ||
from mmgen.registry import MODULES | ||
|
||
|
||
@MODULES.register_module() | ||
class DiffusionResnetBlock(nn.Module): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this |
||
# yapf: disable | ||
"""Resblock for the diffusion model. If `in_channels` not equals to | ||
`out_channels`, a learnable shortcut with conv layers will be added. | ||
|
||
Ref: | ||
https://github.com/CompVis/taming-transformers/blob/master/taming/modules | ||
|
||
Args: | ||
in_channels (int): Number of channels of the input feature map. | ||
out_channels (int, optional): Number of output channels of the | ||
ResBlock. If not defined, the output channels will equal to the | ||
`in_channels`. Defaults to `None`. | ||
conv_shortcut (bool, optional): Whether to use conv_shortcut in | ||
convolution layers. Defaults to `False`. | ||
dropout (float): Probability of the dropout layers. | ||
temb_channels (int, optional): Number of channels of the input time embedding. | ||
Defaults to `512`. | ||
norm_cfg (dict, optional): Config for the norm of output layer. | ||
Defaults to dict(type='BN'). | ||
""" | ||
|
||
def __init__(self, | ||
*, | ||
in_channels, | ||
out_channels=None, | ||
conv_shortcut=False, | ||
dropout, | ||
temb_channels=512, | ||
norm_cfg=dict(type='GN', num_groups=32, eps=1e-6, | ||
affine=True)): | ||
super().__init__() | ||
self.in_channels = in_channels | ||
out_channels = in_channels if out_channels is None else out_channels | ||
self.out_channels = out_channels | ||
self.use_conv_shortcut = conv_shortcut | ||
self.silu = nn.SiLU() | ||
|
||
self.norm1 = build_norm_layer(norm_cfg, in_channels) | ||
self.conv1 = build_conv_layer(None, | ||
in_channels, | ||
out_channels, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1) | ||
if temb_channels > 0: | ||
self.temb_proj = Linear(temb_channels, out_channels) | ||
self.norm2 = build_norm_layer(norm_cfg, out_channels) | ||
self.dropout = nn.Dropout(dropout) | ||
self.conv2 = build_conv_layer(None, | ||
out_channels, | ||
out_channels, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1) | ||
if self.in_channels != self.out_channels: | ||
if self.use_conv_shortcut: | ||
self.conv_shortcut = build_conv_layer(None, | ||
in_channels, | ||
out_channels, | ||
kernel_size=3, | ||
stride=1, | ||
padding=1) | ||
else: | ||
self.nin_shortcut = build_conv_layer(None, | ||
in_channels, | ||
out_channels, | ||
kernel_size=1, | ||
stride=1, | ||
padding=0) | ||
|
||
def forward(self, x, temb): | ||
"""Forward function. | ||
|
||
Args: | ||
x (torch.Tensor): Input feature map tensor. | ||
temb (torch.Tensor): Shared time embedding. | ||
Returns: | ||
torch.Tensor : Output feature map tensor. | ||
""" | ||
h = self.norm1(x) | ||
h = self.silu(h) | ||
h = self.conv1(h) | ||
|
||
if temb is not None: | ||
h = h + self.temb_proj(self.silu(temb))[:, :, None, None] | ||
|
||
h = self.norm2(h) | ||
h = self.silu(h) | ||
h = self.dropout(h) | ||
h = self.conv2(h) | ||
|
||
if self.in_channels != self.out_channels: | ||
if self.use_conv_shortcut: | ||
x = self.conv_shortcut(x) | ||
else: | ||
x = self.nin_shortcut(x) | ||
|
||
return x + h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .quantizer import GumbelQuantize, VectorQuantizer, VectorQuantizer2 | ||
|
||
__all__ = ['GumbelQuantize', 'VectorQuantizer', 'VectorQuantizer2'] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May nn.GLU meets our needs?