From 8ad2ae54382ae81b72616abd62f813347fb1db83 Mon Sep 17 00:00:00 2001
From: kaijieshi <1638961687@qq.com>
Date: Wed, 24 Feb 2021 13:02:55 +0800
Subject: [PATCH] Update label_smooth.py

The smoothed label distribution should sum to 1: the true class gets
lb_pos = 1 - lb_smooth and each of the remaining num_classes - 1 classes
gets lb_neg, so lb_neg should be lb_smooth / (num_classes - 1) instead of
lb_smooth / num_classes.
---
 label_smooth.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/label_smooth.py b/label_smooth.py
index 86e9878..22815dd 100644
--- a/label_smooth.py
+++ b/label_smooth.py
@@ -36,7 +36,7 @@ def forward(self, logits, label):
         ignore = label.eq(self.lb_ignore)
         n_valid = ignore.eq(0).sum()
         label[ignore] = 0
-        lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / num_classes
+        lb_pos, lb_neg = 1. - self.lb_smooth, self.lb_smooth / (num_classes - 1)
         lb_one_hot = torch.empty_like(logits).fill_(
             lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
 
@@ -61,7 +61,7 @@ class LSRCrossEntropyFunctionV2(torch.autograd.Function):
     def forward(ctx, logits, label, lb_smooth, lb_ignore):
         # prepare label
         num_classes = logits.size(1)
-        lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / num_classes
+        lb_pos, lb_neg = 1. - lb_smooth, lb_smooth / (num_classes - 1)
         label = label.clone().detach()
         ignore = label.eq(lb_ignore)
         n_valid = ignore.eq(0).sum()
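
For reference, a minimal sketch (not part of the patch; the toy values and
tensor names below are illustrative assumptions) of why the divisor must be
num_classes - 1 for the smoothed targets to form a valid distribution:

import torch

# Toy setup: 5 classes, smoothing factor 0.1, three example hard labels.
num_classes, lb_smooth = 5, 0.1
label = torch.tensor([2, 0, 4])

lb_pos = 1. - lb_smooth                 # mass kept on the true class
lb_neg = lb_smooth / (num_classes - 1)  # mass given to each other class

# Build smoothed targets the same way the patched code does: fill every
# entry with lb_neg, then scatter lb_pos onto the true class of each row.
lb_one_hot = torch.full((label.size(0), num_classes), lb_neg)
lb_one_hot.scatter_(1, label.unsqueeze(1), lb_pos)

print(lb_one_hot.sum(dim=1))  # each row sums to 1.0
# With the old lb_smooth / num_classes, each row would sum to
# 1 - lb_smooth / num_classes (0.98 here), i.e. not a valid distribution.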