You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Thanks a lot for your paper and code!
In your implementation, you don't set an attention mask for the text sequence in either the text Transformer layers (`textTransEncoder`) or the `LinearTemporalCrossAttention` layers. Why doesn't the padding have any effect on the results? The related code is below.
def encode_text(self, text, device):
    """Encode text prompts with CLIP's text tower followed by a trainable text encoder.

    Args:
        text: prompt(s) passed to ``clip.tokenize``; ``truncate=True`` clips
            over-long prompts instead of raising.
        device: device the token ids are moved to.

    Returns:
        ``(xf_proj, xf_out)``.
        ``xf_proj`` is ``text_proj`` applied to, for each sequence, the feature
        at the position of its largest token id. This is presumably CLIP's
        end-of-text token; confirm against the tokenizer.
        ``xf_out`` is the full per-token sequence, permuted from (T, B, D)
        to (B, T, D).
    """
    # The CLIP stage runs under no_grad, so no gradients flow into it.
    # NOTE(review): indentation was lost in the paste; the no_grad scope is
    # assumed to end after ln_final (see the "T, B, D" section marker below).
    # Confirm against the repository.
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, latent_dim]
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)

    # T, B, D
    # NOTE(review): no padding mask is passed to textTransEncoder, so token
    # positions after the end of the prompt (presumably zero padding from
    # clip.tokenize) take part in its attention. Verify whether that is intended.
    x = self.text_pre_proj(x)
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    # Pool one feature per batch element: row = argmax token position, column = batch index.
    batch_idx = torch.arange(xf_out.shape[1], device=xf_out.device)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), batch_idx])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out
class LinearTemporalCrossAttention(nn.Module):
    """Linear-complexity cross-attention from motion features to text features.

    Queries come from the motion sequence ``x``; keys and values come from the
    text features ``xf``. Queries are softmax-normalized over each head's
    channels and keys over the text axis N. A per-head (HD x HD) key/value
    summary is built first, so no T x N attention matrix is ever formed.
    """

    def __init__(self, seq_len, latent_dim, text_latent_dim, num_head, dropout, time_embed_dim):
        super().__init__()
        # seq_len is accepted for signature compatibility; it is not used here.
        self.num_head = num_head
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(text_latent_dim, latent_dim)
        self.value = nn.Linear(text_latent_dim, latent_dim)
        # NOTE(review): constructed but never applied in forward().
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x, xf, emb, text_mask=None):
        """
        x: B, T, D
        xf: B, N, L
        emb: conditioning forwarded unchanged to ``proj_out``
        text_mask: optional, shape (B, N) or (B, N, 1). Nonzero/True marks text
            positions to attend to. When given, masked positions get zero weight
            in the key softmax over N and contribute zero value. When None
            (the default), all N positions are used, as before.
        Returns ``x + proj_out(y, emb)`` with shape B, T, D.
        """
        B, T, D = x.shape
        N = xf.shape[1]
        H = self.num_head
        # Normalize the text once; it feeds both the key and value projections.
        text_feat = self.text_norm(xf)
        # B, T, D
        query = self.query(self.norm(x))
        # B, N, D
        key = self.key(text_feat)
        # B, N, D
        value = self.value(text_feat)
        if text_mask is not None:
            keep = text_mask.reshape(B, N, 1).bool()
            # Use the dtype's most negative finite value, not -inf, so that a
            # fully masked row yields uniform weights instead of NaN.
            key = key.masked_fill(~keep, torch.finfo(key.dtype).min)
            value = value.masked_fill(~keep, 0.0)
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        # B, N, H, HD
        value = value.view(B, N, H, -1)
        # B, H, HD, HD
        attention = torch.einsum('bnhd,bnhl->bhdl', key, value)
        y = torch.einsum('bnhd,bhdl->bnhl', query, attention).reshape(B, T, D)
        y = x + self.proj_out(y, emb)
        return y
The text was updated successfully, but these errors were encountered:
Thanks a lot for your paper and code!
In your implementation, you don't set an attention mask for the text sequence in either the text Transformer layers (`textTransEncoder`) or the `LinearTemporalCrossAttention` layers. Why doesn't the padding have any effect on the results? The related code is below.
def encode_text(self, text, device):
    """Encode text prompts with CLIP's text tower followed by a trainable text encoder.

    Args:
        text: prompt(s) passed to ``clip.tokenize``; ``truncate=True`` clips
            over-long prompts instead of raising.
        device: device the token ids are moved to.

    Returns:
        ``(xf_proj, xf_out)``.
        ``xf_proj`` is ``text_proj`` applied to, for each sequence, the feature
        at the position of its largest token id. This is presumably CLIP's
        end-of-text token; confirm against the tokenizer.
        ``xf_out`` is the full per-token sequence, permuted from (T, B, D)
        to (B, T, D).
    """
    # The CLIP stage runs under no_grad, so no gradients flow into it.
    # NOTE(review): indentation was lost in the paste; the no_grad scope is
    # assumed to end after ln_final (see the "T, B, D" section marker below).
    # Confirm against the repository.
    with torch.no_grad():
        text = clip.tokenize(text, truncate=True).to(device)
        x = self.clip.token_embedding(text).type(self.clip.dtype)  # [batch_size, n_ctx, latent_dim]
        x = x + self.clip.positional_embedding.type(self.clip.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip.transformer(x)
        x = self.clip.ln_final(x).type(self.clip.dtype)

    # T, B, D
    # NOTE(review): no padding mask is passed to textTransEncoder, so token
    # positions after the end of the prompt (presumably zero padding from
    # clip.tokenize) take part in its attention. Verify whether that is intended.
    x = self.text_pre_proj(x)
    xf_out = self.textTransEncoder(x)
    xf_out = self.text_ln(xf_out)
    # Pool one feature per batch element: row = argmax token position, column = batch index.
    batch_idx = torch.arange(xf_out.shape[1], device=xf_out.device)
    xf_proj = self.text_proj(xf_out[text.argmax(dim=-1), batch_idx])
    # B, T, D
    xf_out = xf_out.permute(1, 0, 2)
    return xf_proj, xf_out
class LinearTemporalCrossAttention(nn.Module):\
The text was updated successfully, but these errors were encountered: