"""
https://arxiv.org/pdf/1907.02610.pdf
"""
import torch
import numpy as np
import torch.optim as optim


def locally_linearity_regularization(model,
                                     loss_fn, x, y, norm, optimizer,
                                     step_size, epsilon=0.031, perturb_steps=10,
                                     lambd=4.0, mu=3.0, version=None):
    model.eval()
    batch_size = len(x)

    def model_grad(x):
        """Gradient of the loss w.r.t. the input, detached from the graph."""
        x.requires_grad_(True)
        lx = loss_fn(model(x), y)
        lx.backward()
        # Clone before zeroing; otherwise the returned tensor would be
        # zeroed in place together with x.grad.
        ret = x.grad.detach().clone()
        x.grad.zero_()
        x.requires_grad_(False)
        return ret
    def grad_dot(x, delta, model_grad):
        """Batch mean of the per-example dot product delta_i^T grad_x l(x_i)."""
        ret = (model_grad.flatten(start_dim=1) * delta.flatten(start_dim=1)).sum(dim=1)
        return torch.mean(ret)

    # calc gamma(eps, x); the paper's full form would also subtract loss_fn(model(x), y)
    def g(x, delta: torch.Tensor, model_grad):
        ret = loss_fn(model(x + delta), y) - grad_dot(x, delta, model_grad)
        return torch.abs(ret)
    mg = model_grad(x)

    if norm in [2, np.inf]:
        # Start from a small random perturbation on the same device as x.
        delta = 0.001 * torch.randn_like(x).detach()
        delta.requires_grad_(True)

        # Optimize delta with SGD to (approximately) maximize gamma(eps, x).
        optimizer_delta = optim.SGD([delta], lr=step_size)

        for _ in range(perturb_steps):
            optimizer_delta.zero_grad()
            loss = (-1) * g(x, delta, mg)
            loss.backward()
            # Renormalize the gradient per example.
            grad_norms = delta.grad.view(batch_size, -1).norm(p=norm, dim=1)
            delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
            # Avoid nan or inf if the gradient is 0.
            if (grad_norms == 0).any():
                delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
            optimizer_delta.step()

            # Projection: keep x + delta in the valid pixel range [0, 1]
            # and delta inside the epsilon-ball.
            delta.data.add_(x)
            delta.data.clamp_(0, 1).sub_(x)
            delta.data.renorm_(p=norm, dim=0, maxnorm=epsilon)
        delta.requires_grad_(False)
    else:
        raise ValueError(f"[LLR] Unsupported norm: {norm}")
    model.train()
    # Clear the gradients accumulated during the inner maximization.
    optimizer.zero_grad()

    # Natural loss plus the local-linearity penalty and the gradient term.
    outputs = model(x)
    loss_natural = loss_fn(outputs, y)
    if version == "sum":
        # "sum" version: scale the gradient term by the batch size.
        loss = loss_natural + lambd * g(x, delta, mg) + mu * grad_dot(x, delta, mg) * len(x)
    else:
        loss = loss_natural + lambd * g(x, delta, mg) + mu * grad_dot(x, delta, mg)
    return outputs, loss
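

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch of a single LLR training step, not part of
# the original module. The toy model, the random batch, and the
# hyper-parameter values below are illustrative assumptions; any classifier
# taking inputs in [0, 1] should be usable the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Tiny stand-in classifier (assumption: CIFAR-10-like 3x32x32 inputs, 10 classes).
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
    ).to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()

    # One random batch of images in [0, 1] with integer labels.
    x = torch.rand(8, 3, 32, 32, device=device)
    y = torch.randint(0, 10, (8,), device=device)

    # Build the LLR loss, then backpropagate and update as usual; the function
    # itself only zeroes the gradients and returns (outputs, loss).
    outputs, loss = locally_linearity_regularization(
        model, loss_fn, x, y, norm=2, optimizer=optimizer,
        step_size=0.007, epsilon=0.031, perturb_steps=10,
        lambd=4.0, mu=3.0,
    )
    loss.backward()
    optimizer.step()

    acc = (outputs.argmax(dim=1) == y).float().mean().item()
    print(f"LLR loss: {loss.item():.4f}, clean accuracy: {acc:.2f}")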