"""Masked and uncertainty-aware training criteria.

All losses treat pixels with target <= 0 as invalid and exclude them
from the loss computation.
"""
import torch
import torch.nn as nn

from utils.util import MAX_8BIT  # kept from the original source; unused in this file

# Loss names, presumably used to select a criterion in the training
# configuration. The '*c' variants additionally take a predicted
# inverse-confidence map.
loss_names = ['l1', 'l2', 'l1l2', 'l1c', 'l2c', 'l1l2c']


def compute_valid_mask(target):
    """Boolean mask of valid pixels (target > 0), detached from the graph."""
    return (target > 0).detach()


def compute_diff(pred, target, valid_mask):
    """Signed residuals (target - pred) at the valid pixels, as a 1-D tensor."""
    diff = target - pred
    return diff[valid_mask]
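
# Illustration of the masking semantics (worked example, not in the
# original file):
#   target = torch.tensor([[0.0, 2.0], [3.0, 0.0]])
#   pred   = torch.tensor([[1.0, 1.0], [1.0, 1.0]])
#   compute_diff(pred, target, compute_valid_mask(target))
#   # -> tensor([1., 2.])  (only the two pixels with target > 0 survive)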


class MaskedMSELoss(nn.Module):
    """Mean squared error over valid (target > 0) pixels."""

    def __init__(self):
        super(MaskedMSELoss, self).__init__()

    def forward(self, pred, target):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = compute_valid_mask(target)
        diff = compute_diff(pred, target, valid_mask)
        self.loss = (diff ** 2).mean()
        return self.loss


def L1(pred, target):
    """Mean absolute error over valid (target > 0) pixels."""
    assert pred.dim() == target.dim(), "inconsistent dimensions"
    valid_mask = compute_valid_mask(target)
    diff = compute_diff(pred, target, valid_mask)
    loss = diff.abs().mean()
    return loss


class MaskedL1Loss(nn.Module):
    """Masked L1 loss; accepts a single prediction or a list of
    intermediate predictions, whose individual losses are summed."""

    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        # `weight` is accepted for interface compatibility but unused.
        if isinstance(pred, list):
            loss = sum(L1(p, target) for p in pred)
        else:
            loss = L1(pred, target)
        return loss


class MaskedL1L2Loss(nn.Module):
    """Sum of the masked L1 and L2 losses over valid pixels."""

    def __init__(self):
        super(MaskedL1L2Loss, self).__init__()

    def forward(self, pred, target, weight=None):
        # `weight` is accepted for interface compatibility but unused.
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = compute_valid_mask(target)
        diff = compute_diff(pred, target, valid_mask)
        l1 = diff.abs().mean()
        l2 = (diff ** 2).mean()
        self.loss = l1 + l2
        return self.loss


class UncertaintyL1Loss(nn.Module):
    """Confidence-weighted L1 loss.

    Per pixel, the squared residual is divided by the squared scaled
    inverse confidence (conf_lambda * conf_inv)**2; the 4 * log1p term
    grows with conf_inv, so the network cannot shrink the loss by
    predicting low confidence everywhere. The L1 variant takes the
    square root of this per-pixel term.
    """

    def __init__(self):
        super(UncertaintyL1Loss, self).__init__()

    def forward(self, pred, target, conf_inv, conf_lambda):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = compute_valid_mask(target)
        conf_inv = conf_inv * conf_lambda
        diff = torch.sqrt((target - pred) ** 2 / (conf_inv ** 2) + 4 * torch.log1p(conf_inv))
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss


class UncertaintyL2Loss(nn.Module):
    """Confidence-weighted L2 loss: the per-pixel squared residual is
    divided by the squared scaled inverse confidence, plus a 4 * log1p
    regularizer on the scaled inverse confidence."""

    def __init__(self):
        super(UncertaintyL2Loss, self).__init__()

    def forward(self, pred, target, conf_inv, conf_lambda):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = compute_valid_mask(target)
        conf_inv = conf_inv * conf_lambda
        diff = (target - pred) ** 2 / (conf_inv ** 2) + 4 * torch.log1p(conf_inv)
        diff = diff[valid_mask]
        self.loss = diff.abs().mean()
        return self.loss


class UncertaintyL1L2Loss(nn.Module):
    """Sum of the confidence-weighted L1 and L2 losses."""

    def __init__(self):
        super(UncertaintyL1L2Loss, self).__init__()

    def forward(self, pred, target, conf_inv, conf_lambda):
        assert pred.dim() == target.dim(), "inconsistent dimensions"
        valid_mask = compute_valid_mask(target)
        conf_inv = conf_inv * conf_lambda
        l2 = (target - pred) ** 2 / (conf_inv ** 2) + 4 * torch.log1p(conf_inv)
        l1 = torch.sqrt(l2)  # the L1 variant is the square root of the per-pixel L2 term
        l2 = l2[valid_mask].abs().mean()
        l1 = l1[valid_mask].abs().mean()
        self.loss = l1 + l2
        return self.loss
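

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how these criteria might be selected and applied.
# The name-to-class mapping below is an assumption inferred from
# `loss_names`; the original training code may wire this up differently.
if __name__ == '__main__':
    criteria = {
        'l1': MaskedL1Loss(),
        'l2': MaskedMSELoss(),
        'l1l2': MaskedL1L2Loss(),
        'l1c': UncertaintyL1Loss(),
        'l2c': UncertaintyL2Loss(),
        'l1l2c': UncertaintyL1L2Loss(),
    }

    pred = torch.rand(2, 1, 8, 8)
    target = torch.rand(2, 1, 8, 8)
    target[:, :, :2, :] = 0  # simulate invalid pixels, excluded by the mask
    # Hypothetical inverse-confidence map; offset away from zero to avoid
    # division by (near-)zero in the uncertainty losses.
    conf_inv = torch.rand(2, 1, 8, 8) + 0.1

    print('l1  :', criteria['l1'](pred, target).item())
    print('l1c :', criteria['l1c'](pred, target, conf_inv, conf_lambda=1.0).item())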