diff --git a/wenet/paraformer/cif.py b/wenet/paraformer/cif.py
index 5ee7c342a8..07a23d3f32 100644
--- a/wenet/paraformer/cif.py
+++ b/wenet/paraformer/cif.py
@@ -153,7 +153,7 @@ def gen_frame_alignments(self,
         else:
             token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
 
-        max_token_num = torch.max(token_num).item()
+        max_token_num = torch.max(token_num)
 
         alphas_cumsum = torch.cumsum(alphas, dim=1)
         alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
diff --git a/wenet/utils/mask.py b/wenet/utils/mask.py
index 0480fb4f6a..7a4d0940dc 100644
--- a/wenet/utils/mask.py
+++ b/wenet/utils/mask.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
 import torch
 '''
 def subsequent_mask(
@@ -197,7 +198,9 @@ def add_optional_chunk_mask(xs: torch.Tensor,
     return chunk_masks
 
 
-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+def make_pad_mask(
+        lengths: torch.Tensor,
+        max_len: Optional[Union[torch.Tensor, int]] = None) -> torch.Tensor:
     """Make mask tensor containing indices of padded part.
 
     See description of make_non_pad_mask.
@@ -215,7 +218,16 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
         [0, 0, 1, 1, 1]]
     """
     batch_size = lengths.size(0)
-    max_len = max_len if max_len > 0 else lengths.max().item()
+    if max_len is None:
+        max_len = torch.max(lengths)
+    else:
+        if isinstance(max_len, int):
+            max_len = torch.tensor(max_len,
+                                   dtype=lengths.dtype,
+                                   device=lengths.device)
+        else:
+            assert isinstance(max_len, torch.Tensor)
+
     seq_range = torch.arange(0,
                              max_len,
                              dtype=torch.int64,
@@ -226,7 +238,8 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
     return mask
 
 
-def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
+def make_non_pad_mask(lengths: torch.Tensor,
+                      max_len: Optional[torch.Tensor] = None) -> torch.Tensor:
     """Make mask tensor containing indices of non-padded part.
 
     The sequences in a batch may have different lengths. To enable
@@ -251,7 +264,7 @@ def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
         [1, 1, 1, 0, 0],
        [1, 1, 0, 0, 0]]
     """
-    return ~make_pad_mask(lengths)
+    return ~make_pad_mask(lengths, max_len)
 
 
 def mask_finished_scores(score: torch.Tensor,
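
A minimal usage sketch of the patched helpers (a hypothetical example, assuming the patched wenet.utils.mask is importable): max_len may now be omitted, passed as a plain int, or passed as a 0-dim tensor, so callers are no longer forced through a Python-int .item() conversion.

import torch

from wenet.utils.mask import make_pad_mask, make_non_pad_mask

lengths = torch.tensor([5, 3, 2])

# max_len omitted: mask width defaults to torch.max(lengths)
print(make_pad_mask(lengths))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True],
#         [False, False,  True,  True,  True]])

# max_len as a plain int: converted to a tensor internally by the patch
print(make_pad_mask(lengths, max_len=6).shape)   # torch.Size([3, 6])

# max_len as a 0-dim tensor (e.g. the max of another length tensor)
print(make_non_pad_mask(lengths, torch.max(lengths)).shape)  # torch.Size([3, 5])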