# Source: eql.py, taken from a fork of open-mmlab/mmdetection.
"""
This code is based on the following file:
https://github.com/tztztztztz/eqlv2/blob/master/mmdet/models/losses/eql.py
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
def get_image_count_frequency(version="v0_5"):
    """Load the per-category statistics for the requested dataset version.

    The dataset helper module is imported lazily inside the function, so
    only the module matching ``version`` has to be importable.

    Args:
        version (str): Dataset identifier. Supported values are
            ``"v0_5"``, ``"v1"``, ``"openimage"`` and ``"NDL"``.

    Returns:
        The value returned by the selected helper
        (``get_image_count_frequency()`` for the two LVIS versions,
        ``get_instance_count()`` for ``"openimage"`` and ``"NDL"``).
        NOTE(review): presumably one number per category -- confirm against
        the helper modules.

    Raises:
        KeyError: If ``version`` is not one of the supported identifiers.
    """
    # Each branch binds the dataset-specific loader to one common local name;
    # the loader is then invoked exactly once at the end.
    if version == "v0_5":
        from mmdet.utils.lvis_v0_5_categories import (
            get_image_count_frequency as load_stats)
    elif version == "v1":
        from mmdet.utils.lvis_v1_0_categories import (
            get_image_count_frequency as load_stats)
    elif version == "openimage":
        from mmdet.utils.openimage_categories import (
            get_instance_count as load_stats)
    elif version == "NDL":
        from mmdet.utils.ndl_categories import (
            get_instance_count as load_stats)
    else:
        raise KeyError(f"version {version} is not supported")
    return load_stats()
@LOSSES.register_module()
class EQL(nn.Module):
    """Equalization Loss (EQL) for classification with sigmoid outputs.

    Every (sample, class) entry of a binary cross-entropy loss is scaled by

        w = 1 - E(sample) * T(class) * (1 - y)

    where ``y`` is the one-hot target, ``E`` is 1 for samples whose label is
    not the background index (see :meth:`exclude_func`) and ``T`` is 1 for
    classes whose frequency is below ``lambda_`` (see
    :meth:`threshold_func`). The negative-class term of a low-frequency class
    is therefore dropped for foreground samples.

    Args:
        use_sigmoid (bool): Stored on the instance; not read by this class.
        reduction (str): Stored on the instance; not read by this class
            (the loss is always summed and divided by the number of samples).
        class_weight: Stored on the instance; not read by this class.
        loss_weight (float): Scalar multiplied into the returned loss.
        lambda_ (float): Frequency threshold; classes with a frequency
            strictly below it are selected by :meth:`threshold_func`.
        version (str): Dataset identifier forwarded to
            ``get_image_count_frequency``.
    """

    def __init__(self,
                 use_sigmoid=True,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0,
                 lambda_=0.00177,
                 version="v0_5"):
        super().__init__()
        self.use_sigmoid = use_sigmoid
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight
        self.lambda_ = lambda_
        self.version = version

        # Per-class frequency table; kept as a plain attribute and moved to
        # the logits' device lazily in ``threshold_func``.
        self.freq_info = torch.FloatTensor(get_image_count_frequency(version))

        num_class_included = torch.sum(self.freq_info < self.lambda_)
        print(f"set up EQL (version {version}), {num_class_included} classes included.")

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Compute the equalization loss.

        Args:
            cls_score (Tensor): Logits of shape ``(n_i, n_c)``.
            label (Tensor): Integer class index per sample; the value
                ``n_c`` is treated as background.
            weight, avg_factor, reduction_override, **kwargs: Accepted for
                interface compatibility; ignored by this implementation.

        Returns:
            Tensor: Scalar loss, i.e. the weighted BCE summed over all
            entries, divided by the number of samples and multiplied by
            ``loss_weight``. An empty batch yields zero rather than NaN.
        """
        # The helper methods below read these values from the instance.
        self.n_i, self.n_c = cls_score.size()
        self.gt_classes = label
        self.pred_class_logits = cls_score

        target = self._expand_label(cls_score, label)
        eql_w = 1 - self.exclude_func() * self.threshold_func() * (1 - target)

        cls_loss = F.binary_cross_entropy_with_logits(cls_score, target,
                                                      reduction='none')
        # Guard the denominator: with zero samples the sum is 0 and dividing
        # by n_i == 0 would produce NaN.
        cls_loss = torch.sum(cls_loss * eql_w) / max(self.n_i, 1)
        return self.loss_weight * cls_loss

    def _expand_label(self, pred, gt_classes):
        """Build the ``(n_i, n_c)`` one-hot target for ``gt_classes``."""
        # One extra column absorbs the background index (n_c); it is dropped
        # afterwards, so background samples get an all-zero target row.
        target = pred.new_zeros(self.n_i, self.n_c + 1)
        rows = torch.arange(self.n_i, device=pred.device)
        target[rows, gt_classes] = 1
        return target[:, :self.n_c]

    def exclude_func(self):
        """Return the instance-level weight ``E`` of shape ``(n_i, n_c)``.

        A row is 1 when the sample's label differs from the background index
        ``n_c`` and 0 otherwise. Reads state set by :meth:`forward`.
        """
        bg_ind = self.n_c
        weight = (self.gt_classes != bg_ind).float()
        return weight.view(self.n_i, 1).expand(self.n_i, self.n_c)

    def threshold_func(self):
        """Return the class-level weight ``T`` of shape ``(n_i, n_c)``.

        A column is 1 when the class frequency is strictly below ``lambda_``
        and 0 otherwise. Reads state set by :meth:`forward`.

        Raises:
            IndexError: If the frequency table has fewer entries than there
                are classes in the logits.
        """
        logits = self.pred_class_logits
        if self.freq_info.numel() < self.n_c:
            raise IndexError(
                f"freq_info has {self.freq_info.numel()} entries but "
                f"cls_score has {self.n_c} classes")

        # Move the table to the logits' device once; ``Tensor.to`` returns
        # the same tensor when it is already there, so later calls are free.
        self.freq_info = self.freq_info.to(logits.device)

        # Vectorized comparison instead of a per-class Python loop. Only the
        # first n_c entries are used, matching the range of the old loop.
        rare = self.freq_info[:self.n_c] < self.lambda_
        weight = rare.to(logits.dtype)
        return weight.view(1, self.n_c).expand(self.n_i, self.n_c)