
Update rpr.py
kjysmu authored Nov 3, 2023
1 parent 5748668 commit 7735567
Showing 1 changed file with 1 addition and 30 deletions.
31 changes: 1 addition & 30 deletions model/rpr.py
@@ -11,12 +11,9 @@
from torch.nn.init import *

from torch.nn.functional import linear, softmax, dropout

from torch.nn import MultiheadAttention

from typing import Optional


class TransformerDecoderRPR(Module):
def __init__(self, decoder_layer, num_layers, norm=None):
super(TransformerDecoderRPR, self).__init__()
@@ -25,7 +22,6 @@ def __init__(self, decoder_layer, num_layers, norm=None):
self.norm = norm

def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):

output = tgt
for mod in self.layers:
output = mod(output, memory, tgt_mask=tgt_mask,
@@ -38,7 +34,6 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_

return output


class TransformerDecoderLayerRPR(Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
super(TransformerDecoderLayerRPR, self).__init__()
@@ -74,7 +69,6 @@ def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
tgt = self.norm3(tgt)
return tgt


# TransformerEncoderRPR (only for music transformer)
class TransformerEncoderRPR(Module):
def __init__(self, encoder_layer, num_layers, norm=None):
@@ -90,6 +84,7 @@ def forward(self, src, mask=None, src_key_padding_mask=None):
if self.norm:
output = self.norm(output)
return output

# TransformerEncoderLayerRPR (only for music transformer)
class TransformerEncoderLayerRPR(Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, er_len=None):
@@ -113,7 +108,6 @@ def forward(self, src, src_mask=None, src_key_padding_mask=None):
src = self.norm2(src)
return src


# MultiheadAttentionRPR
class MultiheadAttentionRPR(Module):
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, er_len=None):
@@ -177,16 +171,6 @@ def forward(self, query, key, value, key_padding_mask=None,
need_weights=True, attn_mask=None):

if hasattr(self, '_qkv_same_embed_dim') and self._qkv_same_embed_dim is False:
# return F.multi_head_attention_forward(
# query, key, value, self.embed_dim, self.num_heads,
# self.in_proj_weight, self.in_proj_bias,
# self.bias_k, self.bias_v, self.add_zero_attn,
# self.dropout, self.out_proj.weight, self.out_proj.bias,
# training=self.training,
# key_padding_mask=key_padding_mask, need_weights=need_weights,
# attn_mask=attn_mask, use_separate_proj_weight=True,
# q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
# v_proj_weight=self.v_proj_weight)

return multi_head_attention_forward_rpr(
query, key, value, self.embed_dim, self.num_heads,
@@ -204,15 +188,6 @@ Please re-train your model with the new module',
Please re-train your model with the new module',
UserWarning)

# return F.multi_head_attention_forward(
# query, key, value, self.embed_dim, self.num_heads,
# self.in_proj_weight, self.in_proj_bias,
# self.bias_k, self.bias_v, self.add_zero_attn,
# self.dropout, self.out_proj.weight, self.out_proj.bias,
# training=self.training,
# key_padding_mask=key_padding_mask, need_weights=need_weights,
# attn_mask=attn_mask)

return multi_head_attention_forward_rpr(
query, key, value, self.embed_dim, self.num_heads,
self.in_proj_weight, self.in_proj_bias,
@@ -255,11 +230,9 @@ def multi_head_attention_forward_rpr(query, # type: Tensor
----------
For Relative Position Representation support (https://arxiv.org/abs/1803.02155)
https://pytorch.org/docs/1.2.0/_modules/torch/nn/functional.html
Modification to take RPR embedding matrix and perform skew optimized RPR (https://arxiv.org/abs/1809.04281)
----------
"""

# type: (...) -> Tuple[Tensor, Optional[Tensor]]

qkv_same = torch.equal(query, key) and torch.equal(key, value)
@@ -295,7 +268,6 @@ def multi_head_attention_forward_rpr(query, # type: Tensor
k = None
v = None
else:

# This is inline in_proj function with in_proj_weight and in_proj_bias
_b = in_proj_bias
_start = embed_dim
@@ -472,7 +444,6 @@ def _skew(qe):
Performs the skew optimized RPR computation (https://arxiv.org/abs/1809.04281)
----------
"""

sz = qe.shape[1]
mask = (torch.triu(torch.ones(sz, sz).to(qe.device)) == 1).float().flip(0)

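Note on the diffed code above: `multi_head_attention_forward_rpr` and `_skew` implement the relative-position-representation (RPR) attention described in the two docstrings, where the relative logits Q @ Er^T are realigned with the "skew" trick from Music Transformer (https://arxiv.org/abs/1809.04281) instead of gathering an L x L x d tensor of relative embeddings. The snippet below is an illustrative sketch of that trick only, not the code elided from this diff; it assumes `Er` holds one learned embedding per relative distance, ordered from -(L-1) up to 0 along its rows.

import torch
import torch.nn.functional as F

def skew_sketch(qe):
    # qe: (batch*heads, L, L) relative logits; column L-1 corresponds to
    # relative distance 0, column 0 to distance -(L-1)  (assumed ordering).
    L = qe.shape[-1]
    padded = F.pad(qe, (1, 0))                        # (B, L, L+1): one dummy column on the left
    reshaped = padded.reshape(qe.shape[0], L + 1, L)  # re-read the same values as (B, L+1, L)
    return reshaped[:, 1:, :]                         # drop the first row -> (B, L, L), aligned by key position

# Toy usage for causal self-attention: add the skewed logits to Q @ K^T before
# masking and softmax; entries above the diagonal are junk but get masked out.
B_heads, L, d = 8, 16, 32
q = torch.randn(B_heads, L, d)
k = torch.randn(B_heads, L, d)
Er = torch.randn(L, d)                                # hypothetical relative-embedding table
scores = (q @ k.transpose(-2, -1) + skew_sketch(q @ Er.t())) / d ** 0.5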
