diff --git a/egs/librispeech/SSL/zipformer/wav2vec2_module.py b/egs/librispeech/SSL/zipformer/wav2vec2_module.py
index ab5ca005f4..ccaaccf46f 100644
--- a/egs/librispeech/SSL/zipformer/wav2vec2_module.py
+++ b/egs/librispeech/SSL/zipformer/wav2vec2_module.py
@@ -22,12 +22,15 @@
 from typing import List, Tuple
 
 import numpy as np
+import random
+from scaling import penalize_abs_values_gt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
 
 
+
 class ConvFeatureExtractionModel(nn.Module):
     def __init__(
         self,
@@ -105,4 +108,8 @@ def forward(self, x):
         for conv in self.conv_layers:
             x = conv(x)
 
+        if self.training and random.random() < 0.2:
+            x = penalize_abs_values_gt(x, limit=1000.0, penalty=1.0e-05,
+                name=(self.name if hasattr(self, 'name') else 'ConvFeatureExtractionModel'))
+
         return x
diff --git a/egs/librispeech/SSL/zipformer/zipformer.py b/egs/librispeech/SSL/zipformer/zipformer.py
index e9eff3357e..0f9d03ae22 100644
--- a/egs/librispeech/SSL/zipformer/zipformer.py
+++ b/egs/librispeech/SSL/zipformer/zipformer.py
@@ -789,7 +789,7 @@ def forward(
         selected_attn_weights = attn_weights[0:1]
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
-        elif not self.training and random.random() < float(self.const_attention_rate):
+        elif self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
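
For reference, the sketch below illustrates the technique the first hunk relies on: penalize_abs_values_gt (imported from scaling.py) is assumed to act as an identity in the forward pass while adding a gradient term that pushes activations whose absolute value exceeds the limit back toward the allowed range. This is a minimal, self-contained illustration under that assumption, not the actual scaling.py implementation (which is more elaborate, e.g. it accepts a name argument for diagnostics); the helper name penalize_abs_values_gt_sketch and the toy usage are hypothetical.

import torch


class _PenalizeAbsValuesGt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, limit: float, penalty: float) -> torch.Tensor:
        # Remember the sign of the elements whose magnitude is over the limit;
        # the forward output is x, numerically unchanged.
        ctx.save_for_backward((x.abs() > limit).to(x.dtype) * x.sign())
        ctx.penalty = penalty
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (over_limit_sign,) = ctx.saved_tensors
        # Add d/dx [ penalty * relu(|x| - limit) ] to the incoming gradient;
        # the non-tensor arguments limit and penalty receive no gradient.
        return grad_output + ctx.penalty * over_limit_sign, None, None


def penalize_abs_values_gt_sketch(x: torch.Tensor, limit: float, penalty: float) -> torch.Tensor:
    return _PenalizeAbsValuesGt.apply(x, limit, penalty)


if __name__ == "__main__":
    # Toy check: only elements with |x| > 1000 get the extra +/- 1e-5 gradient.
    x = (torch.randn(4, 8) * 2000.0).requires_grad_(True)
    y = penalize_abs_values_gt_sketch(x, limit=1000.0, penalty=1.0e-05)
    y.sum().backward()
    print(x.grad)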