fim.py
"""Estimate the diagonal of the Fisher Information Matrix (FIM) of a model
by accumulating squared gradients of per-sample log-likelihoods."""

import time
import sys
from typing import Dict, Optional

import torch
from torch import Tensor
from torch.distributions import Categorical
from torch.nn import Module
from torch.utils.data import DataLoader

# Compare (major, minor) so the check also passes on PyTorch 1.x and later;
# inspecting only the minor version would reject versions such as 1.0.
assert tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (0, 4), \
    "PyTorch 0.4+ required"


def fim_diag(model: Module,
             data_loader: DataLoader,
             samples_no: Optional[int] = None,
             empirical: bool = False,
             device: Optional[torch.device] = None,
             verbose: bool = False,
             every_n: Optional[int] = None) -> Dict[int, Dict[str, Tensor]]:
    """Compute the FIM diagonal of `model` over data from `data_loader`.

    If `empirical` is True, the ground-truth targets are used; otherwise
    labels are sampled from the model's predictive distribution. The model
    is assumed to return log-probabilities (e.g. via a final log_softmax),
    so the gathered output is the per-sample log-likelihood. Returns a dict
    mapping the number of samples seen to the FIM estimate at that point.
    """
    # One accumulator of squared gradients per trainable parameter.
    fim = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            fim[name] = torch.zeros_like(param)

    seen_no = 0   # samples processed so far
    last = 0      # samples processed at the last progress report
    tic = time.time()

    all_fims = {}

    data_iterator = iter(data_loader)
    while samples_no is None or seen_no < samples_no:
        try:
            data, target = next(data_iterator)
        except StopIteration:
            # One pass over the data is enough when no sample budget is set;
            # otherwise restart the loader until the budget is exhausted.
            if samples_no is None:
                break
            data_iterator = iter(data_loader)
            data, target = next(data_iterator)

        if device is not None:
            data = data.to(device)
            if empirical:
                target = target.to(device)

        logits = model(data)
        if empirical:
            # Empirical Fisher: use the ground-truth labels.
            outdx = target.unsqueeze(1)
        else:
            # True Fisher: sample labels from the model's own distribution.
            outdx = Categorical(logits=logits).sample().unsqueeze(1).detach()
        samples = logits.gather(1, outdx)

        idx, batch_size = 0, data.size(0)
        while idx < batch_size and (samples_no is None or seen_no < samples_no):
            model.zero_grad()
            # Per-sample gradient of the log-likelihood; the graph is kept
            # alive so the remaining samples in the batch can be processed.
            torch.autograd.backward(samples[idx], retain_graph=True)
            for name, param in model.named_parameters():
                if param.requires_grad:
                    fim[name] += (param.grad * param.grad)
                    fim[name].detach_()
            seen_no += 1
            idx += 1

            if verbose and seen_no % 100 == 0:
                toc = time.time()
                fps = float(seen_no - last) / (toc - tic)
                tic, last = toc, seen_no
                sys.stdout.write(f"\rSamples: {seen_no:5d}. Fps: {fps:2.4f} samples/s.")

            if every_n and seen_no % every_n == 0:
                # Snapshot the running estimate, normalized by samples seen.
                all_fims[seen_no] = {n: f.clone().div_(seen_no).detach_()
                                     for (n, f) in fim.items()}

    if verbose:
        if seen_no > last:
            toc = time.time()
            fps = float(seen_no - last) / (toc - tic)
        sys.stdout.write(f"\rSamples: {seen_no:5d}. Fps: {fps:2.5f} samples/s.\n")

    # Final estimate: average the squared gradients over all samples seen.
    for name, grad2 in fim.items():
        grad2 /= float(seen_no)
    all_fims[seen_no] = fim

    return all_fims
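

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The toy model, the
# random dataset, and all sizes below are hypothetical placeholders; the one
# real assumption is that the model ends in log_softmax, so the values that
# fim_diag gathers are per-sample log-likelihoods.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch import nn
    from torch.utils.data import TensorDataset

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(),
                          nn.Linear(32, 5), nn.LogSoftmax(dim=1))
    inputs = torch.randn(256, 10)
    targets = torch.randint(0, 5, (256,))
    loader = DataLoader(TensorDataset(inputs, targets), batch_size=32)

    # True-Fisher estimate from 200 sampled labels, with a snapshot every 100.
    fims = fim_diag(model, loader, samples_no=200, every_n=100, verbose=True)
    final = fims[200]
    for name, diag in final.items():
        print(f"{name}: mean diagonal value {diag.mean().item():.3e}")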