regularizers.py
import torch
from torch import nn


class Normalizer(nn.Module):
    """Shifts and scales image tensors by per-channel mean and std."""

    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer('mean', torch.Tensor(mean).reshape((1, -1, 1, 1)))
        self.register_buffer('std', torch.Tensor(std).reshape((1, -1, 1, 1)))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.get_normal(t)

    def get_normal(self, t: torch.Tensor) -> torch.Tensor:
        return (t - self.mean) / self.std

    def get_unit(self, t: torch.Tensor) -> torch.Tensor:
        """Inverse of get_normal: maps standardized tensors back to unit range."""
        return (t * self.std) + self.mean
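
# Example (illustrative, not from the original file): with ImageNet statistics,
# get_unit inverts get_normal:
#   norm = Normalizer(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   x = torch.rand(2, 3, 32, 32)
#   assert torch.allclose(norm.get_unit(norm.get_normal(x)), x, atol=1e-6)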


class TotalVariation(nn.Module):
    """Total variation penalty: the p-norm of differences between neighboring
    pixels along the horizontal, vertical, and both diagonal directions."""

    def __init__(self, p: int = 2):
        super().__init__()
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_wise = x[:, :, :, 1:] - x[:, :, :, :-1]
        y_wise = x[:, :, 1:, :] - x[:, :, :-1, :]
        diag_1 = x[:, :, 1:, 1:] - x[:, :, :-1, :-1]
        diag_2 = x[:, :, 1:, :-1] - x[:, :, :-1, 1:]
        return (x_wise.norm(p=self.p, dim=(2, 3)).mean()
                + y_wise.norm(p=self.p, dim=(2, 3)).mean()
                + diag_1.norm(p=self.p, dim=(2, 3)).mean()
                + diag_2.norm(p=self.p, dim=(2, 3)).mean())


class NormalVariation(TotalVariation):
    """Total variation computed on a standardized copy of the input."""

    def forward(self, x: torch.Tensor, per_sample: bool = True) -> torch.Tensor:
        # Standardize per sample or over the whole batch; the epsilon guards
        # against division by zero for constant inputs.
        std = x.view(x.shape[0], -1).std(dim=-1).view(-1, 1, 1, 1) if per_sample else x.std()
        x = (x - x.mean()) / (std + 1e-4)
        return super().forward(x)


class ColorVariation(TotalVariation):
    """Total variation of differences between adjacent color channels."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rolled = x.roll(shifts=1, dims=-3)  # cyclic shift along the channel axis
        return super().forward(x - rolled)


class L1Norm(nn.Module):
    """Mean per-sample L1 norm."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.norm(p=1, dim=(1, 2, 3)).mean()


class L2Norm(nn.Module):
    """Mean per-sample L2 norm."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.norm(p=2, dim=(1, 2, 3)).mean()


class FakeColorDistribution(nn.Module):
    """Penalizes deviation of per-channel mean and std from the reference
    statistics stored in a Normalizer."""

    def __init__(self, normalizer: Normalizer):
        super().__init__()
        self.normalizer = normalizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Collapse all dimensions except channels.
        view = x.transpose(1, 0).contiguous().view(x.size(1), -1)
        mean, std = view.mean(-1), view.std(-1, unbiased=False)
        mean_loss = (mean.view(-1) - self.normalizer.mean.view(-1)).norm()
        std_loss = (std.view(-1) - self.normalizer.std.view(-1)).norm()
        return mean_loss + std_loss


class FakeBatchNorm(nn.Module):
    """Matches the statistics of the first convolution's activations to the
    running statistics of a pretrained ResNet's first BatchNorm layer."""

    def __init__(self, resnet_function, normalizer: Normalizer):
        super().__init__()
        resnet = resnet_function(pretrained=True)
        self.conv, self.bn = resnet.conv1, resnet.bn1
        self.normalizer = normalizer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(self.normalizer(x))
        # Collapse all dimensions except channels.
        view = x.transpose(1, 0).contiguous().view(x.size(1), -1)
        mean, var = view.mean(1), view.var(1, unbiased=False)
        return (torch.norm(self.bn.running_var.data - var, 2)
                + torch.norm(self.bn.running_mean.data - mean, 2))
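

# Usage sketch (illustrative, not part of the original module): the penalties
# above can be summed while optimizing an input image, e.g. for feature-
# visualization-style objectives. Assumes torchvision is available; the
# ImageNet mean/std constants and the weights below are assumptions.
if __name__ == '__main__':
    from torchvision import models

    normalizer = Normalizer(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    penalties = [
        (1e-4, TotalVariation(p=2)),
        (1e-4, ColorVariation()),
        (1e-5, L2Norm()),
        (1e-2, FakeBatchNorm(models.resnet18, normalizer)),
    ]

    image = torch.rand(1, 3, 224, 224, requires_grad=True)
    optimizer = torch.optim.Adam([image], lr=0.05)

    for _ in range(10):
        optimizer.zero_grad()
        # A real objective would add a task loss; here only the regularizers run.
        loss = sum(weight * penalty(image) for weight, penalty in penalties)
        loss.backward()
        optimizer.step()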