About text mask #32

Open
ZeyuLing opened this issue Jul 30, 2023 · 0 comments

Thanks a lot for your paper and code!
In your implementation, no attention mask is set for the text sequence, neither in the text transformer layers (textTransEncoder) nor in the LinearTemporalCrossAttention layers, so padded token positions are processed like real tokens. Why doesn't this affect the results? The related code is below.

def encode_text(self, text, device):
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, latent_dim]
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)
    # T, B, D
    x = self.text_pre_proj(x)
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), torch.arange(xf_out.shape[1])])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out
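
To make the question concrete, here is a minimal sketch of the kind of padding mask I would have expected for the text sequence. It assumes CLIP-style token ids, where the end-of-text token has the largest id in each row (which is what the text.argmax(dim=-1) lookup above relies on) and every position after it is padding. The helper name text_padding_mask and the toy token ids are only illustrative and are not part of the repository.

```python
import torch


def text_padding_mask(tokens: torch.Tensor) -> torch.Tensor:
    """Return a [B, N] bool mask that is True at padded positions.

    Assumption: the end-of-text token has the largest id in each row and
    everything after it is padding (CLIP-style tokenization).
    """
    eot_idx = tokens.argmax(dim=-1)  # [B], index of the end-of-text token
    positions = torch.arange(tokens.shape[1], device=tokens.device)  # [N]
    return positions[None, :] > eot_idx[:, None]  # [B, N]


# Toy ids for illustration: 49406 = start, 49407 = end-of-text, 0 = padding.
tokens = torch.tensor([[49406, 320, 1929, 49407, 0, 0],
                       [49406, 320, 49407, 0, 0, 0]])
mask = text_padding_mask(tokens)
```

If textTransEncoder is a torch.nn.TransformerEncoder (an assumption, since its definition is not shown here), a mask of this shape could be passed to it as src_key_padding_mask.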

class LinearTemporalCrossAttention(nn.Module):

  def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
      super().__init__()
      self.num_head = num_head
      self.norm = nn.LayerNorm(latent_dim)
      self.text_norm = nn.LayerNorm(text_latent_dim)
      self.query = nn.Linear(latent_dim, latent_dim)
      self.key = nn.Linear(text_latent_dim, latent_dim)
      self.value = nn.Linear(text_latent_dim, latent_dim)
      self.dropout = nn.Dropout(dropout)
      self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

  def forward(self, x, xf, emb):
      """
      x: B, T, D
      xf: B, N, L
      """
      B, T, D = x.shape
      N = xf.shape[1]
      H = self.num_head
      # B, T, D
      query = self.query(self.norm(x))
      # B, N, D
      key = self.key(self.text_norm(xf))
      query = F.softmax(query.view(B, T, H, -1), dim=-1)
      key = F.softmax(key.view(B, N, H, -1), dim=1)
      # B, N, H, HD
      value = self.value(self.text_norm(xf)).view(B, N, H, -1)
      # B, H, HD, HD
      attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
      y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
      y = x + self.proj_out(y, emb)
      return y
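
For the cross-attention, the key softmax runs over the token axis (dim=1), so padded tokens currently receive a non-zero weight. Below is a rough sketch of how a mask could be applied there, written as a standalone function on already-projected tensors. It leaves out the LayerNorm, the linear projections and the StylizationBlock from the class above, and the function name and the pad_mask argument are my own assumptions, not something from the repository.

```python
import torch
import torch.nn.functional as F


def masked_linear_cross_attention(query, key, value, pad_mask, num_head):
    """Linear cross-attention that ignores padded text tokens.

    query:    [B, T, D] projected motion features
    key:      [B, N, D] projected text features
    value:    [B, N, D] projected text features
    pad_mask: [B, N] bool, True where the text token is padding
    Each row needs at least one unmasked token, otherwise the softmax
    over the token axis produces NaN.
    """
    B, T, D = query.shape
    N = key.shape[1]
    H = num_head
    # Padded tokens get -inf logits, so they receive zero weight in the
    # softmax over the token axis below.
    key = key.masked_fill(pad_mask[:, :, None], float("-inf"))
    q = F.softmax(query.reshape(B, T, H, -1), dim=-1)  # B, T, H, HD
    k = F.softmax(key.reshape(B, N, H, -1), dim=1)     # B, N, H, HD
    v = value.reshape(B, N, H, -1)                     # B, N, H, HD
    attention = torch.einsum('bnhd,bnhl->bhdl', k, v)  # B, H, HD, HD
    return torch.einsum('bthd,bhdl->bthl', q, attention).reshape(B, T, D)
```

With such a mask, the output for a padded sequence should match the output for the same sequence with the padding removed, which is not the case for the unmasked version.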