
dallemini module v1 #416

Open
wants to merge 13 commits into base: dalle
Conversation

@hexunlin (Collaborator) commented Aug 30, 2022

dallemini module v1 + vqgan module v1

@hexunlin requested a review from plyfager on Aug 30, 2022, 20:34


@MODULES.register_module()
class GLU(nn.Module):
Collaborator: Might nn.GLU meet our needs?
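For reference, a minimal sketch (not the PR's code) of what torch.nn.GLU does compared with a gated feed-forward block in the DALL-E mini style; the GatedFFN class and its layer names are assumptions, and, as the author notes later in this thread, nn.GLU turned out not to be a drop-in replacement because it has no learned projections:

    import torch
    import torch.nn as nn

    x = torch.randn(2, 16, 64)

    # nn.GLU has no learnable weights: it splits the last dim in half and
    # returns a * sigmoid(b), so the output has half the channels.
    glu = nn.GLU(dim=-1)
    print(glu(x).shape)  # torch.Size([2, 16, 32])

    # A gated feed-forward block wraps the gating with a norm and learned
    # projections, which nn.GLU alone cannot express.
    class GatedFFN(nn.Module):  # hypothetical stand-in for the PR's GLU
        def __init__(self, in_out_channels, mid_channels):
            super().__init__()
            self.norm = nn.LayerNorm(in_out_channels)
            self.proj_gate = nn.Linear(in_out_channels, mid_channels)
            self.proj_value = nn.Linear(in_out_channels, mid_channels)
            self.proj_out = nn.Linear(mid_channels, in_out_channels)

        def forward(self, x):
            h = self.norm(x)
            gate = torch.sigmoid(self.proj_gate(h))
            return self.proj_out(gate * self.proj_value(h))

    print(GatedFFN(64, 128)(x).shape)  # torch.Size([2, 16, 64])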



@MODULES.register_module()
class EncoderLayer(nn.Module):
Collaborator: Since this encoder layer is used for BART, BartEncoderLayer may be a better name.



@MODULES.register_module()
class DecoderLayer(nn.Module):
Collaborator: Since this decoder layer is used for BART, BartDecoderLayer may be a better name.



@MODULES.register_module()
class AttentionBase(nn.Module):
Collaborator: Might this module meet our needs?


def __init__(self, in_out_channels, mid_channels):
super().__init__()
self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)[1]
Collaborator: _, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels) seems better than indexing with [1].
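For clarity, a minimal sketch of the suggested unpacking style, assuming mmcv's build_norm_layer (which returns a (name, layer) tuple); the wrapper class is hypothetical:

    import torch.nn as nn
    from mmcv.cnn import build_norm_layer


    class BlockSketch(nn.Module):  # hypothetical wrapper, only to show the style
        def __init__(self, in_out_channels):
            super().__init__()
            # Unpack the (name, layer) tuple instead of indexing with [1].
            _, self.norm1 = build_norm_layer(dict(type='LN'), in_out_channels)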


def __init__(self, in_channels, head_num, out_channels):
super().__init__()
self.selfAttention = AttentionBase(in_channels, head_num)
Collaborator: Maybe we can just name it self.attn.

x = self.selfAttention(q, k, v, attention_mask)
x = self.norm(x)
x = residual + x
residual = x.clone()
Collaborator: In fact, you can just write the code this way:

        h = self.glu(x)
        x = h + x

instead of:

        residual = x.clone()
        x = self.glu(x)
        x = residual + x

x (torch.FloatTensor): Output feature map.
"""
residual = x.clone()
x = self.norm(x)
Collaborator: Write forward in this way, without using clone:

        h = self.norm(x)
        h = xxx(h)
        x = x + h
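A minimal sketch of this pre-norm residual style as a complete forward, without clone(); the layer choices (LayerNorm, MultiheadAttention, a Linear feed-forward) are stand-ins for the PR's modules:

    import torch
    import torch.nn as nn


    class EncoderLayerSketch(nn.Module):  # hypothetical, only to show the residual style
        def __init__(self, channels, head_num):
            super().__init__()
            self.norm = nn.LayerNorm(channels)
            self.attn = nn.MultiheadAttention(channels, head_num, batch_first=True)
            self.ffn = nn.Linear(channels, channels)

        def forward(self, x, key_padding_mask=None):
            h = self.norm(x)
            h, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask)
            x = x + h            # residual connection, no clone() needed
            x = x + self.ffn(x)  # second residual in the same style
            return x


    print(EncoderLayerSketch(64, 4)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])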

self.crossAttention = AttentionBase(in_channels, head_num)
self.norm = build_norm_layer(dict(type='LN'), in_channels)[1]
self.glu = GLU(in_channels, out_channels)
self.token_indices = torch.arange(256, device=device)
Collaborator: You may expose 256 as an argument of the init function.
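A minimal sketch of exposing the hard-coded 256 as an init argument; the name num_tokens is an assumption, not the PR's naming:

    import torch
    import torch.nn as nn


    class DecoderSketch(nn.Module):  # hypothetical, only to show the argument
        def __init__(self, num_tokens=256):
            super().__init__()
            self.num_tokens = num_tokens

        def forward(self, x):
            # Indices are created on the input's device rather than stored in init.
            return torch.arange(self.num_tokens, device=x.device)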

in_channels (int): The channel number of the input feature map.
head_num (int): Number of heads in the attention.
out_channels (int): The channel number of the output feature map.
device (str): The type of device (cpu or cuda).
Collaborator: In fact, device is not supposed to be set in the init function; MMCV or MMEngine will put the model on the correct device, or you can just call model.to(device) from outside. If you really need the device of a model, call get_module_device, or use type_as for tensors.
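A minimal sketch of device-agnostic tensor creation along the lines of this comment; get_module_device is the MMGen helper mentioned above, but since its import path isn't shown here, the sketch uses plain-PyTorch equivalents:

    import torch
    import torch.nn as nn


    def make_token_indices(module: nn.Module, num_tokens: int) -> torch.Tensor:
        # Read the device from the module's own parameters at call time.
        device = next(module.parameters()).device
        return torch.arange(num_tokens, device=device)


    def match_device(new: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        # For tensors, type_as (or .to) moves/casts `new` to match `ref`.
        return new.type_as(ref)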

from mmgen.registry import MODULES


def nonlinearity(x):
Collaborator: May we use nn.SiLU?
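A quick sketch of the equivalence, assuming nonlinearity here is the usual x * sigmoid(x): nn.SiLU computes exactly that, so the free function can be dropped:

    import torch
    import torch.nn as nn

    x = torch.randn(4)
    assert torch.allclose(nn.SiLU()(x), x * torch.sigmoid(x))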

return x * activate(x)


def Normalize(in_channels):
Collaborator: We do not need to add an extra function here; just use build_norm_layer directly in your code.

Collaborator: norm_cfg can be exposed as an argument.
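A minimal sketch of taking norm_cfg as an argument instead of a fixed Normalize() helper; the GroupNorm default shown here is an assumption:

    import torch.nn as nn
    from mmcv.cnn import build_norm_layer


    class ResBlockSketch(nn.Module):  # hypothetical, only to show the configurable norm
        def __init__(self, in_channels, norm_cfg=dict(type='GN', num_groups=32)):
            super().__init__()
            _, self.norm1 = build_norm_layer(norm_cfg, in_channels)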



@MODULES.register_module()
class DiffusionDownsample(nn.Module):
Collaborator: I'm wondering whether we should call this module DiffusionDownsample. 😂

Collaborator: This module is just a single stride-2 conv or an avg_pool; we may not need to add an extra class here.
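A minimal sketch of what this reduces to, a single stride-2 conv or an average pool that can be built inline where it is needed (the kernel and padding choices here are assumptions):

    import torch.nn as nn


    def make_downsample(channels, with_conv=True):
        # Either a learnable stride-2 conv or a parameter-free average pool.
        if with_conv:
            return nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)
        return nn.AvgPool2d(kernel_size=2, stride=2)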



@MODULES.register_module()
class DiffusionResnetBlock(nn.Module):
Collaborator: If this resblock is the same as the one in the diffusion UNet, you may find it in the existing diffusion architecture code.

@plyfager (Collaborator) left a comment:

MMGen already supports DDPM; you may check whether you can reuse its modules (UNet, downsample, upsample).

@plyfager (Collaborator) commented Sep 7, 2022

For the above comments, please check whether other lines have the same problems and fix them as well.

Fixed format in dalle_mini and vqgan modules; extended Downsample in ddpm; GLU can't be replaced by nn.GLU(); temporarily keep AttentionBase and DiffusionResblock (needs further testing).
quantizer module for vqgan
@zengyh1900 added the awaiting response, kind/feature request (new feature/model/datasets/config etc.), and priority/P0 (highest priority) labels on Oct 12, 2022
@zengyh1900 added this to the Backlog milestone on Oct 12, 2022
@plyfager added the status/WIP (work in progress) label and removed the awaiting response label on Oct 13, 2022
@OpenMMLab-Assistant005

Hi @hexunlin! We are grateful for your efforts in helping improve this open-source project during your personal time.
Welcome to join the OpenMMLab Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA
If you have a WeChat account, welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends. :)
Thank you again for your contribution ❤
