-
Notifications
You must be signed in to change notification settings - Fork 74
/
loss.py
37 lines (32 loc) · 1.25 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.nn as nn
import torch
import torch.nn.functional as F
def flatten(tensor):
    """Collapse every axis except the channel axis into a single axis.

    The channel axis (dim 1) is moved to the front and all remaining axes
    are merged behind it, e.g.::

        (N, C, D, H, W) -> (C, N * D * H * W)
    """
    num_channels = tensor.size(1)
    # Swapping dims 0 and 1 puts channels first and leaves the trailing axes
    # in their original order: (N, C, ...) -> (C, N, ...).
    channels_first = tensor.transpose(0, 1)
    # The swapped tensor is generally non-contiguous; make it contiguous so
    # that it can be viewed as a 2-D (C, -1) matrix.
    packed = channels_first.contiguous()
    return packed.view(num_channels, -1)
class DiceLoss(nn.Module):
    """Soft Dice loss over softmax probabilities, averaged across channels.

    For each channel ``c`` (dim 1) the Dice coefficient is computed as
    ``2 * sum(p_c * t_c) / (sum(p_c) + sum(t_c))``, where ``p`` is the softmax
    of the raw scores and the sums run over every axis except dim 1. The loss
    is ``1 - mean_c(dice_c)``, so a perfect prediction yields 0.

    Args:
        epsilon: Lower bound applied to the Dice denominator. It prevents a
            0 / 0 = NaN when a channel is exactly zero in both the softmax
            output and the target.
    """

    def __init__(self, epsilon: float = 1e-5):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the mean-over-channels Dice loss as a scalar tensor.

        Args:
            output: Raw (pre-softmax) scores with the channel axis at dim 1.
            target: Tensor of the same shape as ``output``; presumably a
                one-hot / per-class membership map (confirm against caller).

        Returns:
            A 0-dim tensor equal to ``1 - mean(per-channel Dice)``.

        Raises:
            AssertionError: If ``output`` and ``target`` differ in shape.
        """
        assert output.size() == target.size(), "'input' and 'target' must have the same shape"

        probs = F.softmax(output, dim=1)

        # Reduce over every axis except the channel axis. This is the same as
        # reshaping to (C, -1) and summing the last axis, but without having
        # to materialise a permuted copy of both tensors.
        reduce_dims = tuple(d for d in range(probs.dim()) if d != 1)
        intersect = (probs * target).sum(dim=reduce_dims)
        denominator = (probs + target).sum(dim=reduce_dims)

        # The factor 2 makes the coefficient reach 1 for a perfect match.
        # Clamping (rather than adding epsilon) leaves ordinary values
        # untouched and only guards the degenerate all-zero channel.
        dice = 2.0 * intersect / denominator.clamp(min=self.epsilon)
        return 1 - torch.mean(dice)