-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
53 lines (38 loc) · 1.25 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn.functional as F
from dice_loss import DiceLoss
def psnr_mse(input, target):
    """Return PSNR in dB derived from the mean-squared error, with peak value 1.0.

    Computed as ``20 * log10(1 / sqrt(MSE))``. Identical inputs give MSE == 0,
    so the result is ``inf``.
    """
    rmse = torch.sqrt(F.mse_loss(input, target))
    return 20.0 * torch.log10(1.0 / rmse)
def psnr_mae(input, target):
    """Return a PSNR-style score in dB derived from the mean absolute error.

    Computed as ``20 * log10(1 / sqrt(MAE))``, with peak value 1.0.
    """
    # NOTE(review): MAE is already in signal units, so sqrt() here makes this
    # 10*log10(1/MAE) rather than 20*log10(1/MAE); looks copied from psnr_mse —
    # confirm this is intended before relying on the absolute dB values.
    root_mae = torch.sqrt(F.l1_loss(input, target))
    return 20.0 * torch.log10(1.0 / root_mae)
def bce_mse(input, target):
    """Return binary cross-entropy plus mean-squared error.

    ``F.binary_cross_entropy`` expects ``input`` to already be probabilities in
    [0, 1], not logits (unlike the "bce" entry, which works on logits).
    """
    return F.binary_cross_entropy(input, target) + F.mse_loss(input, target)
def weighted_bce(input, target, *, pos_weight=2.0):
    """Return binary cross-entropy on logits, with positive targets up-weighted.

    Args:
        input: raw logits (passed to ``F.binary_cross_entropy_with_logits``).
        target: targets with the same shape as ``input``.
        pos_weight: multiplier applied to the loss of positive examples.
            Defaults to 2.0, the value previously hard-coded here.

    Returns:
        The mean-reduced weighted BCE loss.
    """
    # type_as puts the weight on the same dtype/device as the logits, so the
    # call works for GPU and half-precision inputs too.
    weight = torch.tensor([pos_weight]).type_as(input)
    return F.binary_cross_entropy_with_logits(input, target, pos_weight=weight)
def dice_loss(input, target):
    """Return the Dice loss between ``input`` and ``target``.

    Thin functional wrapper that builds a fresh ``DiceLoss`` module on each call
    and applies it. See the ``dice_loss`` module for the expected input format.
    """
    return DiceLoss()(input, target)
def compute_ts_road_map(road_map1, road_map2, *, empty_score=1.0):
    """Compute the mean threat score of road images for an entire batch.

    Per sample, threat score = TP / (|A| + |B| - TP), where TP is the summed
    elementwise product of the two maps. The sums are taken over dims 1 and 2;
    presumably inputs are (batch, H, W) — TODO confirm against callers.

    Args:
        road_map1: first batch of road maps (bool, int or float tensor).
        road_map2: second batch of road maps, same shape as ``road_map1``.
        empty_score: score for a sample whose union is empty (both maps all
            zero). Without this, such a sample gives 0/0 = NaN and makes the
            whole batch mean NaN.

    Returns:
        Scalar tensor: the mean per-sample threat score.
    """
    sum_dims = (1, 2)
    tp = (road_map1 * road_map2).sum(dim=sum_dims)
    union = road_map1.sum(dim=sum_dims) + road_map2.sum(dim=sum_dims) - tp
    nonempty = union > 0
    # Divide by 1 where the union is empty, so the discarded branch of
    # torch.where never holds NaN. A NaN there would leak into the gradients.
    safe_union = torch.where(nonempty, union, torch.ones_like(union))
    ratio = tp * 1.0 / safe_union
    ts = torch.where(nonempty, ratio, torch.full_like(ratio, empty_score))
    return ts.mean()
# Registry of loss/metric callables, looked up by name; each takes (input, target).
LOSS = {
    # Logit-based BCE variants (sigmoid is applied inside the loss).
    "bce": F.binary_cross_entropy_with_logits,
    "weighted_bce": weighted_bce,
    # Plain regression losses.
    "mse": F.mse_loss,
    "mae": F.l1_loss,
    # Expects probabilities in [0, 1], not logits.
    "bce+mse": bce_mse,
    "dice_loss": dice_loss,
    # NOTE(review): PSNR grows as the error shrinks; confirm callers maximize
    # (or negate) these entries rather than minimizing them.
    "psnr_mse": psnr_mse,
    "psnr_mae": psnr_mae,
}