dallemini module v1 #416
base: dalle
Conversation
dallemini modules v1
vqgan modules v1
update format2
update_format3
update format4
update format
update format6
@MODULES.register_module()
class GLU(nn.Module):
Maybe nn.GLU meets our needs?
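For reference, a rough sketch of how the built-in module would be wired in (the channel sizes below are only illustrative):

import torch
import torch.nn as nn

# nn.GLU splits its input in half along `dim` and returns a * sigmoid(b),
# so the projection that doubles the channel count has to happen outside it.
proj = nn.Linear(256, 512)
glu = nn.GLU(dim=-1)

x = torch.randn(2, 16, 256)
out = glu(proj(x))  # shape: (2, 16, 256)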
@MODULES.register_module()
class EncoderLayer(nn.Module):
Since this encoder layer is used for Bart, BartEncoderLayer may be a better name.
@MODULES.register_module()
class DecoderLayer(nn.Module):
Since this decoder layer is used for Bart, BartDecoderLayer may be a better name.
@MODULES.register_module()
class AttentionBase(nn.Module):
Maybe this module meets our needs?
def __init__(self, in_out_channels, mid_channels):
    super().__init__()
    self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)[1]
_, self.norm1 = build_norm_layer(...) seems better than indexing with [1].
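For context, build_norm_layer returns a (name, layer) tuple, so unpacking reads better than indexing; a rough sketch (the surrounding class body is only illustrative):

import torch.nn as nn
from mmcv.cnn import build_norm_layer


class GLU(nn.Module):  # illustrative skeleton only

    def __init__(self, in_out_channels, mid_channels):
        super().__init__()
        # unpack the (name, layer) tuple instead of taking element [1]
        _, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)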
def __init__(self, in_channels, head_num, out_channels):
    super().__init__()
    self.selfAttention = AttentionBase(in_channels, head_num)
Maybe we can just name it self.attn.
x = self.selfAttention(q, k, v, attention_mask)
x = self.norm(x)
x = residual + x
residual = x.clone()
In fact, you can just write the code in this way:
h = self.glu(x)
x = h + x
instead of:
residual = x.clone()
x = self.glu(x)
x = residual + x
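Applied to the snippet above, the whole forward could read roughly like this (the simplified signature and treating x itself as the residual stream are assumptions on my side):

def forward(self, x, attention_mask):
    # keep the running value in x and add each sub-layer's output to it,
    # so no clone() is needed for the residual connections
    h = self.selfAttention(x, x, x, attention_mask)
    h = self.norm(h)
    x = x + h

    h = self.glu(x)
    x = x + h
    return x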
    x (torch.FloatTensor): Output feature map.
"""
residual = x.clone()
x = self.norm(x)
Write the forward in this way:
h = self.norm(x)
h = xxx(h)
x = x + h
without using clone.
self.crossAttention = AttentionBase(in_channels, head_num)
self.norm = build_norm_layer(dict(type='LN'), in_channels)[1]
self.glu = GLU(in_channels, out_channels)
self.token_indices = torch.arange(256, device=device)
You may set 256 as an argument in the init function.
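A rough sketch of what that could look like (the argument name num_positions is only illustrative, and registering the indices as a buffer is an extra suggestion so they follow the module across devices):

import torch
import torch.nn as nn


class DecoderLayer(nn.Module):  # illustrative skeleton only

    def __init__(self, in_channels, head_num, out_channels, num_positions=256):
        super().__init__()
        # expose the hard-coded 256 as a constructor argument; a buffer moves
        # with the module, so no device argument is needed here
        self.register_buffer('token_indices', torch.arange(num_positions))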
in_channels (int): The channel number of the input feature map.
head_num (int): Number of heads in the attention.
out_channels (int): The channel number of the output feature map.
device (str): The type of device (cpu or cuda).
In fact, device is not supposed to be set in the init function. MMCV or MMEngine will put the model on the correct device, or you can just call model.to(device) outside.
If you really need to get the device of a model, call get_module_device, or you may use type_as for tensors.
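Two sketches of what that looks like in practice; the import path of get_module_device is an assumption based on mmgen's layout and should be double-checked:

import torch
import torch.nn as nn
# assumed import path for get_module_device
from mmgen.models.architectures.common import get_module_device


class DecoderLayer(nn.Module):  # illustrative skeleton only

    def forward(self, x):
        # option 1: query the module's device only when a tensor is created
        device = get_module_device(self)
        indices = torch.arange(256, device=device)

        # option 2: let type_as match the input's device, then cast back to long
        indices = torch.arange(256).type_as(x).long()
        return indices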
from mmgen.registry import MODULES


def nonlinearity(x):
May we use nn.SiLU?
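If activate is sigmoid here, x * activate(x) is exactly SiLU (swish), so the helper could be dropped; a minimal sketch:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 8)
out_module = nn.SiLU()(x)   # module form, usable inside nn.Sequential
out_functional = F.silu(x)  # functional form, drop-in for nonlinearity(x)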
    return x * activate(x)


def Normalize(in_channels):
We do not need to add an extra function here. Just use build_norm_layer in your code.
norm_cfg can be set as an argument.
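Combining both points, a sketch with norm_cfg exposed as an argument and build_norm_layer used directly (the GroupNorm default is only an assumption mirroring common diffusion setups):

import torch.nn as nn
from mmcv.cnn import build_norm_layer


class DiffusionResnetBlock(nn.Module):  # illustrative skeleton only

    def __init__(self, in_channels, norm_cfg=dict(type='GN', num_groups=32)):
        super().__init__()
        # the normalization type is now configurable from outside, and the
        # custom Normalize() helper is no longer needed
        _, self.norm1 = build_norm_layer(norm_cfg, in_channels)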
@MODULES.register_module()
class DiffusionDownsample(nn.Module):
I'm wondering whether we should call this module DiffusionDownsample. 😂
This module is a single stride-2 conv or avg_pool. We may not need to add an extra class here.
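That is, the downsample could be created inline where it is needed (the channel count below is illustrative):

import torch.nn as nn

in_channels = 128
# learnable variant: a single stride-2 convolution
downsample = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=1)
# parameter-free variant: average pooling
downsample = nn.AvgPool2d(kernel_size=2, stride=2)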
@MODULES.register_module()
class DiffusionResnetBlock(nn.Module):
If this resblock is the same as the one in the diffusion UNet, you may find it in the diffusion architecture.
MMGen has already supported DDPM; you may check whether you can reuse its modules (unet, downsample, upsample).
For the above comments, please check whether other lines have the same problem and fix them as well.
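For example, the existing DDPM blocks could be imported instead of re-implemented; the module path, class names, and constructor arguments below are my recollection of mmgen's ddpm code and are assumptions to verify:

# assumed import path and class names from mmgen's DDPM implementation
from mmgen.models.architectures.ddpm.modules import (DenoisingDownsample,
                                                     DenoisingResBlock)

# reuse the existing stride-2 downsample instead of DiffusionDownsample
downsample = DenoisingDownsample(in_channels=128, with_conv=True)  # args assumed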
Fixed format in dalle_mini and vqgan modules; extended Downsample in ddpm; GLU can't be replaced by nn.GLU(); temporarily keeping AttentionBase and DiffusionResblock (needs further testing).
quantizer module for vqgan
Hi @hexunlin! We are grateful for your efforts in helping improve this open-source project during your personal time.
dallemini module v1 + vqgan module v1