Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[firred] bring fired aed to wenet #2680

Merged
merged 12 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 186 additions & 0 deletions wenet/firered/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright (c) 2025 Wenet Community. authors: Mddct(Dinghao Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple, Union

import torch
from wenet.transformer.attention import (T_CACHE,
RelPositionMultiHeadedAttention)
from wenet.transformer.embedding import PositionalEncoding


class FireRedRelPositionalEncoding(PositionalEncoding):

def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):

super().__init__(d_model, dropout_rate, max_len)
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
-(torch.log(torch.tensor(10000.0)).item() / d_model))
pe_positive[:, 0::2] = torch.sin(position * div_term)
pe_positive[:, 1::2] = torch.cos(position * div_term)
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
pe_negative = pe_negative[1:].unsqueeze(0)
pe = torch.cat([pe_positive, pe_negative], dim=1)
self.register_buffer('pe', pe)

def position_encoding(self,
offset: Union[int, torch.Tensor],
size: int,
apply_dropout: bool = True) -> torch.Tensor:

raise NotImplementedError('firedasr not support streaming pos encding')

def forward(self, x, offset=None):
Tmax, T = self.pe.size(1), x.size(1)
pos_emb = self.pe[:, Tmax // 2 - T + 1:Tmax // 2 + T].clone().detach()
return self.dropout(x), self.dropout(pos_emb)


class FiredRelPositionMultiHeadedAttention(RelPositionMultiHeadedAttention):
"""Multi-Head Attention layer with relative position encoding.
Paper: https://arxiv.org/abs/1901.02860
Args:
n_head (int): The number of heads.
n_feat (int): The number of features.
dropout_rate (float): Dropout rate.
"""

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
query_bias: bool = True,
key_bias: bool = True,
value_bias: bool = True,
use_sdpa: bool = False,
n_kv_head: Optional[int] = None,
head_dim: Optional[int] = None):
"""Construct an RelPositionMultiHeadedAttention object."""
super().__init__(n_head, n_feat, dropout_rate, query_bias, key_bias,
value_bias, use_sdpa, n_kv_head, head_dim)

self.layer_norm_q = torch.nn.LayerNorm(n_feat)
self.layer_norm_k = torch.nn.LayerNorm(n_feat)
self.layer_norm_v = torch.nn.LayerNorm(n_feat)

def rel_shift(self, x):
"""Compute relative positinal encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, size).
Returns:
torch.Tensor: Output tensor.
"""

zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
device=x.device,
dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=-1)

x_padded = x_padded.view(x.size()[0],
x.size()[1],
x.size(3) + 1, x.size(2))
x = x_padded[:, :, 1:].view_as(x)
x = x[:, :, :, :x.size(-1) // 2 + 1]

return x

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
query = self.layer_norm_q(query)
key = self.layer_norm_k(key)
value = self.layer_norm_v(value)
return super().forward_qkv(query, key, value)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: T_CACHE = (torch.zeros((0, 0, 0, 0)), torch.zeros((0, 0, 0, 0)))
) -> Tuple[torch.Tensor, T_CACHE]:
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
pos_emb (torch.Tensor): Positional embedding tensor
(#batch, time2, size).
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
Returns:
torch.Tensor: Output tensor (#batch, time1, d_model).
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
where `cache_t == chunk_size * num_decoding_left_chunks`
and `head * d_k == size`
"""
q, k, v = self.forward_qkv(query, key, value)
q = q.transpose(1, 2) # (batch, time1, head, d_k)
k, v, new_cache = self._update_kv_and_cache(k, v, cache)

n_batch_pos = pos_emb.size(0)
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
p = p.transpose(1, 2) # (batch, head, time1, d_k)

# (batch, head, time1, d_k)
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

# compute matrix b and matrix d
# (batch, head, time1, time2)
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
matrix_bd = self.rel_shift(matrix_bd)
if not self.use_sdpa:
# compute attention score
# first compute matrix a and matrix c
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
# (batch, head, time1, time2)
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

scores = (matrix_ac + matrix_bd) / math.sqrt(
self.d_k) # (batch, head, time1, time2)

return self.forward_attention(v, scores, mask), new_cache
else:
# NOTE(Mddct): we need mask bias, not boolean mask
assert mask.dtype != torch.bool
mask = mask.unsqueeze(1)
# matrix_bd as a mask bias
mask = (matrix_bd + mask) / math.sqrt(self.d_k)
output = torch.nn.functional.scaled_dot_product_attention(
q_with_bias_u,
k,
v,
attn_mask=mask,
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = (output.transpose(1, 2).contiguous().view(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache
Loading
Loading