dallemini module v1 #416

Open. Wants to merge 13 commits into base: dalle.
4 changes: 4 additions & 0 deletions mmgen/models/architectures/dalle_mini/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .modules import BartDecoderLayer, BartEncoderLayer

__all__ = ['BartDecoderLayer', 'BartEncoderLayer']
233 changes: 233 additions & 0 deletions mmgen/models/architectures/dalle_mini/modules.py
@@ -0,0 +1,233 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.bricks import Linear, build_activation_layer, build_norm_layer
from mmgen.registry import MODULES


@MODULES.register_module()
class GLU(nn.Module):
Collaborator: Might nn.GLU meet our needs?

"""GLU variants used to improve Transformer.

Args:
in_out_channels (int): The channel number of the input
and the output feature map.
mid_channels (int): The channel number of the middle layer feature map.
"""

def __init__(self, in_out_channels, mid_channels):
super().__init__()
_, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)
_, self.norm2 = build_norm_layer(dict(type='LN'), mid_channels)
self.fc1 = Linear(in_out_channels, mid_channels, bias=False)
self.fc2 = Linear(in_out_channels, mid_channels, bias=False)
self.fc3 = Linear(mid_channels, in_out_channels, bias=False)
self.gelu = build_activation_layer(dict(type='GELU'))

def forward(self, z):
"""Forward function.

Args:
z (torch.FloatTensor): Input feature map.

Returns:
z (torch.FloatTensor): Output feature map.
"""
z = self.norm1(z)
w = self.fc1(z)
w = self.gelu(w)
v = self.fc2(z)
z = self.norm2(w * v)
z = self.fc3(z)
return z
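On the question above about `nn.GLU`: `torch.nn.GLU` only splits the last dimension in half and gates it with a sigmoid, with no learnable projections, whereas the block here is a GELU-gated variant with two projections and LayerNorms. A minimal sketch of the difference (shapes are illustrative, not from the PR):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 16, 64)  # (batch, tokens, channels), illustrative sizes

# torch.nn.GLU splits the last dimension in half and applies a sigmoid gate,
# so 64 channels go in and 32 come out, with no learnable parameters.
plain_glu = nn.GLU(dim=-1)
print(plain_glu(x).shape)  # torch.Size([2, 16, 32])

# The GLU module in this PR instead computes
#   fc3(norm2(gelu(fc1(norm1(x))) * fc2(norm1(x))))
# and keeps the channel count at 64, so nn.GLU is not a drop-in replacement.
```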


@MODULES.register_module()
class AttentionBase(nn.Module):
Collaborator: Might this module meet our needs?

"""An Muti-head Attention block used in Bart model.

Ref:
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models

Args:
in_channels (int): The channel number of the input feature map.
num_heads (int): Number of heads in the attention.
"""

def __init__(self, in_channels, num_heads):
super().__init__()
self.in_channels = in_channels
self.num_heads = num_heads
        self.query = Linear(in_channels, in_channels, bias=False)
self.key = Linear(in_channels, in_channels, bias=False)
self.value = Linear(in_channels, in_channels, bias=False)
self.proj = Linear(in_channels, in_channels, bias=False)

def qkv(self, x):
"""Calculate queries, keys and values for the embedding map.

Args:
x (torch.FloatTensor): Input feature map.

Returns:
            q (torch.FloatTensor): Query feature map.
k (torch.FloatTensor): Key feature map.
v (torch.FloatTensor): Value feature map.
"""
        q = self.query(x)
k = self.key(x)
v = self.value(x)

return q, k, v

def forward(self, q, k, v, attention_mask):
"""Forward function for attention.

Args:
            q (torch.FloatTensor): Query feature map.
k (torch.FloatTensor): Key feature map.
v (torch.FloatTensor): Value feature map.
            attention_mask (torch.BoolTensor): Boolean mask marking the key
                positions that may be attended to.

Returns:
weights (torch.FloatTensor): Feature map after attention.
"""
q = q.reshape(q.shape[:2] + (self.num_heads, -1))
q /= q.shape[-1]**0.5
k = k.reshape(k.shape[:2] + (self.num_heads, -1))
v = v.reshape(v.shape[:2] + (self.num_heads, -1))

attention_bias = (1 - attention_mask.to(torch.float32)) * -1e12
weights = torch.einsum('bqhc,bkhc->bhqk', q, k)
weights += attention_bias
weights = torch.softmax(weights, -1)
weights = torch.einsum('bhqk,bkhc->bqhc', weights, v)
shape = weights.shape[:2] + (self.in_channels, )
weights = weights.reshape(shape)
weights = self.proj(weights)
return weights
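A quick shape check of `AttentionBase` as written; the mask is a boolean tensor of allowed key positions that broadcasts against the `(batch, heads, query, key)` attention weights, and all sizes below are illustrative assumptions:

```python
import torch

batch, seq_len, channels, heads = 2, 8, 32, 4  # illustrative sizes

attn = AttentionBase(in_channels=channels, num_heads=heads)
x = torch.randn(batch, seq_len, channels)

# True marks key positions that may be attended to; the shape broadcasts
# against the (batch, heads, query, key) attention weights.
mask = torch.ones(batch, 1, 1, seq_len, dtype=torch.bool)

q, k, v = attn.qkv(x)
out = attn(q, k, v, mask)
print(out.shape)  # torch.Size([2, 8, 32])
```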


@MODULES.register_module()
class BartEncoderLayer(nn.Module):
# yapf: disable
"""EncoderLayer of the Bart model.

Ref:
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models

Args:
in_channels (int): The channel number of the input feature map.
head_num (int): Number of heads in the attention.
out_channels (int): The channel number of the output feature map.
"""

def __init__(self, in_channels, head_num, out_channels):
super().__init__()
self.attn = AttentionBase(in_channels, head_num)
_, self.norm = build_norm_layer(dict(type='LN'), in_channels)
self.glu = GLU(in_channels, out_channels)

def forward(self, x, attention_mask):
"""Forward function for the encoder layer.

Args:
x (torch.FloatTensor): Input feature map.
            attention_mask (torch.BoolTensor): Boolean mask marking the key
                positions that may be attended to.

Returns:
x (torch.FloatTensor): Output feature map.
"""

h = self.norm(x)
q, k, v = self.attn.qkv(h)
h = self.attn(q, k, v, attention_mask)
h = self.norm(h)
x = x + h
h = self.glu(x)
x = x + h
return x
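For reference, the registered layer can also be built from a config dict through the MMGen registry, assuming the dalle_mini module has been imported so the registration decorator has run; sizes are illustrative, and `out_channels` is the hidden width of the GLU block:

```python
import torch
from mmgen.registry import MODULES

cfg = dict(type='BartEncoderLayer', in_channels=32, head_num=4, out_channels=128)
layer = MODULES.build(cfg)  # assumes the dalle_mini module has been imported

x = torch.randn(2, 8, 32)                        # (batch, tokens, channels)
mask = torch.ones(2, 1, 1, 8, dtype=torch.bool)  # all positions visible
print(layer(x, mask).shape)                      # torch.Size([2, 8, 32])
```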


@MODULES.register_module()
class BartDecoderLayer(nn.Module):
# yapf: disable
"""DecoderLayer of the Bart model.

Ref:
https://github.com/kuprel/min-dalle/blob/main/min_dalle/models

Args:
in_channels (int): The channel number of the input feature map.
head_num (int): Number of heads in the attention.
out_channels (int): The channel number of the output feature map.
token_length (int): The length of tokens.
"""

def __init__(self, in_channels, head_num, out_channels, token_length=256):
super().__init__()
self.attn = AttentionBase(in_channels, head_num)
self.cross_attn = AttentionBase(in_channels, head_num)
_, self.norm = build_norm_layer(dict(type='LN'), in_channels)
self.glu = GLU(in_channels, out_channels)
self.token_indices = torch.arange(token_length)

def forward(self, x, encoder_state, attention_state,
attention_mask, token_index):
"""Forward function for the decoder layer.

Args:
x (torch.FloatTensor): Input feature map of
the decoder embeddings.
encoder_state (torch.FloatTensor): Input feature map of
the encoder embeddings.
attention_state (torch.FloatTensor): Input feature map of
the attention.
            attention_mask (torch.BoolTensor): Boolean mask marking the
                encoder key positions that may be attended to.
token_index (torch.LongTensor): The index of tokens.

Returns:
x (torch.FloatTensor): Output feature map of
the decoder embeddings.
attention_state (torch.FloatTensor): Output feature map of
the attention.
"""

# Self Attention
token_count = token_index.shape[1]
if token_count == 1:
self_attn_mask = self.token_indices <= token_index
self_attn_mask = self_attn_mask[:, None, None, :]
else:
self_attn_mask = (self.token_indices[None, None, :token_count] <=
token_index[:, :, None])
self_attn_mask = self_attn_mask[:, None, :, :]

h = self.norm(x)
q, k, v = self.attn.qkv(h)
if token_count == 1:
batch_count = h.shape[0]
attn_state_new = torch.cat([k, v]).to(attention_state.dtype)
attention_state[:, token_index[0]] = attn_state_new
k = attention_state[:batch_count]
v = attention_state[batch_count:]
h = self.attn(q, k, v, self_attn_mask)
h = self.norm(h)
x = x + h

# Cross Attention
h = self.norm(x)
        q, _, _ = self.cross_attn.qkv(h)
        _, k, v = self.cross_attn.qkv(encoder_state)
h = self.cross_attn(q, k, v, attention_mask)
h = self.norm(h)
x = x + h

h = self.glu(x)
x = x + h

return x, attention_state
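A sketch of how the cached `attention_state` appears to be used during single-token (incremental) decoding: the caller pre-allocates a buffer whose first `batch` rows hold cached keys and whose last `batch` rows hold cached values, as implied by the `attention_state[:batch_count]` / `[batch_count:]` split above. All names and sizes below are assumptions for illustration:

```python
import torch

batch, channels, text_len, token_length = 2, 32, 64, 256  # illustrative sizes
layer = BartDecoderLayer(in_channels=channels, head_num=4,
                         out_channels=128, token_length=token_length)

# Keys live in the first `batch` rows, values in the last `batch` rows.
attention_state = torch.zeros(2 * batch, token_length, channels)

x = torch.randn(batch, 1, channels)                 # one new decoder token
encoder_state = torch.randn(batch, text_len, channels)
attention_mask = torch.ones(batch, 1, 1, text_len, dtype=torch.bool)
token_index = torch.full((batch, 1), 5, dtype=torch.long)  # decode position 5

x, attention_state = layer(x, encoder_state, attention_state,
                           attention_mask, token_index)
print(x.shape)  # torch.Size([2, 1, 32])
```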
8 changes: 7 additions & 1 deletion mmgen/models/architectures/ddpm/modules.py
@@ -366,10 +366,13 @@ class DenoisingDownsample(nn.Module):
downsampled.
with_conv (bool, optional): Whether use convolution operation for
downsampling. Defaults to `True`.
        with_pad (bool, optional): Whether to apply asymmetric padding before
            downsampling. Defaults to `False`.
"""

def __init__(self, in_channels, with_conv=True):
def __init__(self, in_channels, with_conv=True, with_pad=False):
super().__init__()
self.with_pad = with_pad
if with_conv:
self.downsample = nn.Conv2d(in_channels, in_channels, 3, 2, 1)
else:
@@ -383,6 +386,9 @@ def forward(self, x):
Returns:
torch.Tensor: Feature map after downsampling.
"""
if self.with_pad:
# do asymmetric padding
x = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
return self.downsample(x)
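For clarity on the new `with_pad` branch: `F.pad(x, (0, 1, 0, 1))` pads only the right and bottom edges by one pixel (the tuple lists left/right of the last dimension, then top/bottom of the one before it). A quick check of the effect on the spatial shape, independent of the downsampling op itself; the input size is illustrative:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 31, 31)                    # odd spatial size, illustrative
padded = F.pad(x, (0, 1, 0, 1), mode='constant', value=0)
print(padded.shape)                               # torch.Size([1, 64, 32, 32])
```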


4 changes: 4 additions & 0 deletions mmgen/models/architectures/vqgan/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .modules import DiffusionResnetBlock

__all__ = ['DiffusionResnetBlock']
107 changes: 107 additions & 0 deletions mmgen/models/architectures/vqgan/modules.py
@@ -0,0 +1,107 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import Linear, build_conv_layer, build_norm_layer
from mmgen.registry import MODULES


@MODULES.register_module()
class DiffusionResnetBlock(nn.Module):
Collaborator: If this resblock is the same as the one in the diffusion UNet, you may be able to find it in the diffusion architecture.

# yapf: disable
"""Resblock for the diffusion model. If `in_channels` not equals to
`out_channels`, a learnable shortcut with conv layers will be added.

Ref:
https://github.com/CompVis/taming-transformers/blob/master/taming/modules

Args:
in_channels (int): Number of channels of the input feature map.
        out_channels (int, optional): Number of output channels of the
            ResBlock. If not defined, the output channels will equal
            `in_channels`. Defaults to `None`.
        conv_shortcut (bool, optional): Whether to use a 3x3 convolution
            instead of a 1x1 convolution for the learnable shortcut.
            Defaults to `False`.
        dropout (float): Probability of the dropout layers.
        temb_channels (int, optional): Number of channels of the input
            time embedding. Defaults to `512`.
        norm_cfg (dict, optional): Config for the normalization layers.
            Defaults to `dict(type='GN', num_groups=32, eps=1e-6, affine=True)`.
"""

def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
norm_cfg=dict(type='GN', num_groups=32, eps=1e-6,
affine=True)):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.silu = nn.SiLU()

        _, self.norm1 = build_norm_layer(norm_cfg, in_channels)
self.conv1 = build_conv_layer(None,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = Linear(temb_channels, out_channels)
        _, self.norm2 = build_norm_layer(norm_cfg, out_channels)
self.dropout = nn.Dropout(dropout)
self.conv2 = build_conv_layer(None,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = build_conv_layer(None,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = build_conv_layer(None,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)

def forward(self, x, temb):
"""Forward function.

Args:
x (torch.Tensor): Input feature map tensor.
temb (torch.Tensor): Shared time embedding.
Returns:
            torch.Tensor: Output feature map tensor.
"""
h = self.norm1(x)
h = self.silu(h)
h = self.conv1(h)

if temb is not None:
h = h + self.temb_proj(self.silu(temb))[:, :, None, None]

h = self.norm2(h)
h = self.silu(h)
h = self.dropout(h)
h = self.conv2(h)

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)

return x + h
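A minimal instantiation sketch; note that all constructor arguments after the bare `*` are keyword-only and `dropout` has no default, so it must always be passed. The channel sizes below are illustrative and chosen to be divisible by the 32 GroupNorm groups:

```python
import torch

block = DiffusionResnetBlock(in_channels=64, out_channels=128,
                             dropout=0.1, temb_channels=512)

x = torch.randn(2, 64, 16, 16)      # feature map
temb = torch.randn(2, 512)          # shared time embedding
print(block(x, temb).shape)         # torch.Size([2, 128, 16, 16])
```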
4 changes: 4 additions & 0 deletions mmgen/models/architectures/vqvae/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .quantizer import GumbelQuantize, VectorQuantizer, VectorQuantizer2

__all__ = ['GumbelQuantize', 'VectorQuantizer', 'VectorQuantizer2']