focal_loss_with_smoothing.py
#!/usr/bin/python
# -*- encoding: utf-8 -*-
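"""
Focal loss combined with label smoothing for dense (e.g. segmentation) targets:
smoothed one-hot targets are weighted by the focal term (1 - p_t) ** gamma.
"""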
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossWithSmoothing(nn.Module):
    def __init__(
            self,
            num_classes: int,
            gamma: float = 1,
            lb_smooth: float = 0.1,
            size_average: bool = True,
            ignore_index: Optional[int] = None,
            alpha: Optional[float] = None):
        """
        :param num_classes: number of classes in the classification problem
        :param gamma: focusing parameter of the focal term; 0 disables focusing
        :param lb_smooth: label-smoothing mass spread over the non-target classes
        :param size_average: if True return the mean loss, otherwise the sum
        :param ignore_index: label value whose positions are excluded from the loss
        :param alpha: optional balancing factor in (0, 1); stored but not applied below
        """
        super(FocalLossWithSmoothing, self).__init__()
        self._num_classes = num_classes
        self._gamma = gamma
        self._lb_smooth = lb_smooth
        self._size_average = size_average
        self._ignore_index = ignore_index
        self._log_softmax = nn.LogSoftmax(dim=1)
        self._alpha = alpha
        if self._num_classes <= 1:
            raise ValueError('The number of classes must be 2 or higher')
        if self._gamma < 0:
            raise ValueError('Gamma must be 0 or higher')
        if self._alpha is not None:
            if self._alpha <= 0 or self._alpha >= 1:
                raise ValueError('Alpha must be within the open interval (0, 1)')

    def forward(self, logits, label):
        """
        :param logits: raw scores of shape (batch_size, num_classes, height, width)
        :param label: integer class indices of shape (batch_size, height, width)
        :return: scalar loss
        """
        logits = logits.float()
        with torch.no_grad():
            label = label.clone().detach()
            if self._ignore_index is not None:
                # Zero out ignored labels so the one-hot encodings stay valid;
                # their positions are masked out of the loss below.
                ignore = label.eq(self._ignore_index)
                label[ignore] = 0
            # Smoothed one-hot target: lb_pos on the true class, lb_neg elsewhere.
            lb_pos, lb_neg = 1. - self._lb_smooth, self._lb_smooth / (self._num_classes - 1)
            lb_one_hot = torch.empty_like(logits).fill_(
                lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach()
        # Estimate the focal weights only after ignored labels have been replaced,
        # otherwise out-of-range indices would break the one-hot encoding.
        difficulty_level = self._estimate_difficulty_level(logits, label)
        logs = self._log_softmax(logits)
        loss = -torch.sum(difficulty_level * logs * lb_one_hot, dim=1)
        if self._ignore_index is not None:
            loss[ignore] = 0
        return loss.mean() if self._size_average else loss.sum()

    def _estimate_difficulty_level(self, logits, label):
        """
        Compute the focal modulation term (1 - p_t) ** gamma.
        :param logits: raw scores of shape (batch_size, num_classes, height, width)
        :param label: integer class indices of shape (batch_size, height, width)
        :return: per-class focal weights shaped like logits
        """
        one_hot_key = torch.nn.functional.one_hot(label, num_classes=self._num_classes)
        if len(one_hot_key.shape) == 4:
            # (batch, height, width, classes) -> (batch, classes, height, width)
            one_hot_key = one_hot_key.permute(0, 3, 1, 2)
        if one_hot_key.device != logits.device:
            one_hot_key = one_hot_key.to(logits.device)
        # pt holds p_t at the target channel and zero elsewhere, so the weight is
        # (1 - p_t) ** gamma on the target class and 1 on all other classes.
        pt = one_hot_key * F.softmax(logits, dim=1)
        difficulty_level = torch.pow(1 - pt, self._gamma)
        return difficulty_level
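

# A minimal usage sketch, not part of the original file: the shapes below
# (batch, classes, height, width) and the ignore_index value 255 are
# assumptions chosen for illustration.
if __name__ == '__main__':
    torch.manual_seed(0)
    criterion = FocalLossWithSmoothing(num_classes=4, gamma=2, lb_smooth=0.1,
                                       ignore_index=255)
    logits = torch.randn(2, 4, 8, 8)        # raw network outputs
    label = torch.randint(0, 4, (2, 8, 8))  # integer class map
    label[0, 0, 0] = 255                    # this position is excluded from the loss
    loss = criterion(logits, label)
    print(loss.item())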