-
Notifications
You must be signed in to change notification settings - Fork 0
/
loss.py
63 lines (49 loc) · 1.74 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
def vector_diffs(lines):
    """
    Orthogonality regularizer for decomposition vectors.

    Each entry is flattened to an (n_comp, n_size) matrix; the Gram matrix
    of its rows is formed and the mean absolute value of the off-diagonal
    entries (pairwise inner products between components) is accumulated.

    Input:
    - lines: list of tensors whose shape[1:-1] is (n_comp, n_size)
      (e.g. (1, n_comp, n_size, 1) — assumed, confirm against callers).
    """
    total = 0
    for line in lines:
        n_comp, n_size = line.shape[1:-1]
        mat = line.view(n_comp, n_size)
        gram = mat @ mat.t()  # component-wise inner products
        # Drop the diagonal: flatten, skip the first element, and reshape so
        # every remaining diagonal entry lands in the trimmed last column.
        off_diag = gram.flatten()[1:].view(n_comp - 1, n_comp + 1)[..., :-1]
        total = total + off_diag.abs().mean()
    return total
def L1_VM(planes, lines):
    """
    L1 regularization for the VM decomposition.

    Input:
    - planes: list of matrices (2D tensors), the VM decomposition planes.
    - lines: list of vectors (1D tensors), the VM decomposition lines.

    Returns the sum over components of mean(|plane|) + mean(|line|).
    """
    total = 0
    for idx, plane in enumerate(planes):
        total = total + plane.abs().mean() + lines[idx].abs().mean()
    return total
def _tensor_size(t):
return t.size()[1]*t.size()[2]*t.size()[3]
def _TVloss(x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = _tensor_size(x[:,:,1:,:])
count_w = _tensor_size(x[:,:,:,1:])
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
return 2*(h_tv/count_h+w_tv/count_w)/batch_size
def TVloss(planes, factor=1e-2):
    """
    Total variation loss summed over a list of VM planes.

    Input:
    - planes: list of VM plane tensors (each Tensor[1, R, h, w]).
    - factor: scalar weight applied to each plane's TV term.
    """
    return sum(_TVloss(plane) * factor for plane in planes)