From 7735567ef9327b103e2dd2862b9ad662408ffede Mon Sep 17 00:00:00 2001
From: Jaeyong Kang
Date: Fri, 3 Nov 2023 09:10:55 +0800
Subject: [PATCH] Update rpr.py

---
 model/rpr.py | 31 +------------------------------
 1 file changed, 1 insertion(+), 30 deletions(-)

diff --git a/model/rpr.py b/model/rpr.py
index d5f5cab5..15734517 100644
--- a/model/rpr.py
+++ b/model/rpr.py
@@ -11,12 +11,9 @@ from torch.nn.init import *
 from torch.nn.functional import linear, softmax, dropout
 
 
-
 from torch.nn import MultiheadAttention
 
-
 from typing import Optional
-
 class TransformerDecoderRPR(Module):
     def __init__(self, decoder_layer, num_layers, norm=None):
         super(TransformerDecoderRPR, self).__init__()
@@ -25,7 +22,6 @@ def __init__(self, decoder_layer, num_layers, norm=None):
         self.norm = norm
 
     def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
-
         output = tgt
         for mod in self.layers:
             output = mod(output, memory, tgt_mask=tgt_mask,
@@ -38,7 +34,6 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_
 
         return output
 
-
 class TransformerDecoderLayerRPR(Module):
     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
         super(TransformerDecoderLayerRPR, self).__init__()
@@ -74,7 +69,6 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
 
         tgt = self.norm3(tgt)
         return tgt
-
 # TransformerEncoderRPR (only for music transformer)
 class TransformerEncoderRPR(Module):
     def __init__(self, encoder_layer, num_layers, norm=None):
@@ -90,6 +84,7 @@ def forward(self, src, mask=None, src_key_padding_mask=None):
         if self.norm:
             output = self.norm(output)
         return output
+
 # TransformerEncoderLayerRPR (only for music transformer)
 class TransformerEncoderLayerRPR(Module):
     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
@@ -113,7 +108,6 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
 
         src = self.norm2(src)
         return src
-
 # MultiheadAttentionRPR
 class MultiheadAttentionRPR(Module):
     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
@@ -177,16 +171,6 @@ def forward(self, query, key, value, key_padding_mask=None,
                 need_weights=True, attn_mask=None):
         if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
 
-            # return F.multi_head_attention_forward(
-            #     query, key, value, self.embed_dim, self.num_heads,
-            #     self.in_proj_weight, self.in_proj_bias,
-            #     self.bias_k, self.bias_v, self.add_zero_attn,
-            #     self.dropout, self.out_proj.weight, self.out_proj.bias,
-            #     training=self.training,
-            #     key_padding_mask=key_padding_mask, need_weights=need_weights,
-            #     attn_mask=attn_mask, use_separate_proj_weight=True,
-            #     q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
-            #     v_proj_weight=self.v_proj_weight)
 
             return multi_head_attention_forward_rpr(
                 query, key, value, self.embed_dim, self.num_heads,
@@ -204,15 +188,6 @@ def forward(self, query, key, value, key_padding_mask=None,
                 Please re-train your model with the new module',
                           UserWarning)
 
-            # return F.multi_head_attention_forward(
-            #     query, key, value, self.embed_dim, self.num_heads,
-            #     self.in_proj_weight, self.in_proj_bias,
-            #     self.bias_k, self.bias_v, self.add_zero_attn,
-            #     self.dropout, self.out_proj.weight, self.out_proj.bias,
-            #     training=self.training,
-            #     key_padding_mask=key_padding_mask, need_weights=need_weights,
-            #     attn_mask=attn_mask)
-
             return multi_head_attention_forward_rpr(
                 query, key, value, self.embed_dim, self.num_heads,
                 self.in_proj_weight, self.in_proj_bias,
@@ -255,11 +230,9 @@ def multi_head_attention_forward_rpr(query, # type: Tensor
     ----------
     For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
     https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html
-
     Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281)
     ----------
     """
-
     # type: (...) -> Tuple[Tensor, Optional[Tensor]]
 
     qkv_same = torch.equal(query, key) and torch.equal(key, value)
@@ -295,7 +268,6 @@ def multi_head_attention_forward_rpr(query, # type: Tensor
             k = None
             v = None
         else:
-
             # This is inline in_proj function with in_proj_weight and in_proj_bias
             _b = in_proj_bias
             _start = embed_dim
@@ -472,7 +444,6 @@ def _skew(qe):
     Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
     ----------
     """
-
     sz = qe.shape[1]
     mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)
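
Note, not part of the patch: the _skew helper touched in the final hunk implements the pad-and-reshape "skewing" trick from Music Transformer (https://arxiv.org/abs/1809.04281), which lines up the relative-position logits Q @ Er^T with absolute key positions without materialising a separate relative-embedding tensor per query/key pair. Below is a minimal sketch of the paper's version of that trick; the names (skew_sketch, q, er) are illustrative only, and the patched file itself uses a masked variant (see the torch.triu mask line in the last hunk).

    import torch
    import torch.nn.functional as F

    def skew_sketch(qe):
        # qe: (batch*heads, L, L) relative logits, i.e. Q @ Er^T.
        # Pad one zero column on the left, reshape to (L+1, L) per batch
        # element, and drop the first row; the surviving entries are the
        # relative logits shifted into alignment with absolute positions.
        batch, seq_len, _ = qe.shape
        qe = F.pad(qe, (1, 0))                        # (batch, L, L+1)
        qe = qe.reshape(batch, seq_len + 1, seq_len)  # (batch, L+1, L)
        return qe[:, 1:, :]                           # (batch, L, L)

    # Tiny smoke test with random tensors.
    q = torch.randn(2, 5, 8)       # (batch*heads, L, d_head)
    er = torch.randn(5, 8)         # one learned embedding per relative distance
    srel = skew_sketch(torch.matmul(q, er.transpose(0, 1)))
    print(srel.shape)              # torch.Size([2, 5, 5])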