diff --git a/label_smooth.py b/label_smooth.py index 86e9878..22815dd 100644 --- a/label_smooth.py +++ b/label_smooth.py @@ -36,7 +36,7 @@ def forward(self, logits, label): ignore = label.eq(self.lb_ignore) n_valid = ignore.eq(0).sum() label[ignore] = 0 - lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes + lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / (num_classes - 1) lb_one_hot = torch.empty_like(logits).fill_( lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach() @@ -61,7 +61,7 @@ class LSRCrossEntropyFunctionV2(torch.autograd.Function): def forward(ctx, logits, label, lb_smooth, lb_ignore): # prepare label num_classes = logits.size(1) - lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes + lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / (num_classes - 1) label = label.clone().detach() ignore = label.eq(lb_ignore) n_valid = ignore.eq(0).sum()