-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattach_hooks.py
147 lines (135 loc) · 7.92 KB
/
attach_hooks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
from enum import Enum
from torch import nn
import torch
import pickle
import copy
# Coverage modes a caller can request. create_hook checks for REGION, KSECTION
# and BOUNDARY; STRONG is not checked there (strong coverage is derived from
# the BOUNDARY counters in compute_boundary_strong_coverage).
Modes = Enum('Modes', ['REGION', 'KSECTION', 'BOUNDARY', 'STRONG'])
# Activation-function tags that get_all_activation_layers passes to create_hook.
Act_fct = Enum('Act_fct', ['GELU', 'TANH'])
# Module-level list, empty at import time.
hooks = []
# Module-level statistics store. get_all_activation_layers hands this object to
# create_hook as the accumulation dict; read_activations may rebind the name.
activations = {}
# Device that hook tensors are moved to: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Creating a hook
def create_hook(name, out, modes, model_type, activation, k=10, batch_size=4):
    """Build a forward hook that accumulates coverage statistics in ``out``.

    The returned callable has the ``(module, input, output)`` signature expected
    by ``nn.Module.register_forward_hook``. It never returns a value, so the
    module's output is left unchanged.

    Statistics are reduced over dim 0 of the module output and stored as:

    * ``out["low"][name]`` / ``out["high"][name]`` -- running element-wise
      minimum / maximum (``Modes.REGION``).
    * ``out["KSECTION"][name][step]`` -- per-element hit counters for each of
      the ``k`` equal-width sections between low and high (``Modes.KSECTION``).
    * ``out["BOUNDARY"][name]["strong"|"weak"]`` -- per-element counters of
      values above ``high`` / below ``low`` (``Modes.BOUNDARY``).

    Args:
        name: Key under which this layer's statistics are stored.
        out: Dict that is mutated in place with the statistics above.
        modes: Container of ``Modes`` members selecting what to collect.
        model_type: ``"vilt"`` tracks every call; ``"albef"`` tracks only
            outputs of shape ``(batch_size, 1, 768)``; anything else is ignored.
        activation: Accepted for interface compatibility; not used by the hook.
        k: Number of sections for k-section coverage.
        batch_size: Batch size used in the ``"albef"`` shape filter.

    Raises:
        Exception: From the hook, when KSECTION or BOUNDARY is requested but no
            low/high bounds exist for ``name`` yet.
    """

    def bounds_on_device():
        # Explicit check rather than ``assert`` inside a bare ``except``:
        # asserts vanish under ``python -O`` and the bare except hid other errors.
        if ("low" not in out or "high" not in out
                or name not in out["low"] or name not in out["high"]):
            raise Exception('Calculate boundary dicts low and max first')
        # Stored bounds may live on another device (e.g. loaded from disk).
        return out["low"][name].to(device), out["high"][name].to(device)

    def hook(module, in_tensor, out_tensor):
        # detach() instead of copy.deepcopy(): deepcopy raises for non-leaf
        # tensors when autograd is enabled, and no copy is needed because the
        # tensor is only read below.
        out_tensor = out_tensor.detach().to(device)

        tracked = model_type == "vilt" or (
            model_type == "albef"
            and out_tensor.shape == torch.Size([batch_size, 1, 768])
        )
        if not tracked:
            return

        # Only use this mode with training data
        if Modes.REGION in modes:
            batch_low = torch.min(out_tensor, dim=0).values
            batch_high = torch.max(out_tensor, dim=0).values
            lows = out.setdefault("low", {})
            highs = out.setdefault("high", {})
            if name in lows:
                # Widen the previously recorded region with this batch.
                batch_low = torch.min(lows[name].to(device), batch_low)
                batch_high = torch.max(highs[name].to(device), batch_high)
            lows[name] = batch_low
            highs[name] = batch_high

        # Computes K-Section Coverage
        if Modes.KSECTION in modes:
            low, high = bounds_on_device()
            # Keep the device-resident bounds so later calls skip the transfer.
            out["low"][name], out["high"][name] = low, high
            sections = out.setdefault(Modes.KSECTION.name, {}).setdefault(name, {})
            step_size = (high - low) / k
            for step in range(k):
                if step not in sections:
                    sections[step] = torch.zeros(out_tensor.shape[1:]).to(device)
                above_lower = torch.ge(out_tensor, low + step * step_size)
                if step == k - 1:
                    # Close the last section at ``high`` so a value equal to the
                    # recorded maximum is counted instead of falling in no section.
                    below_upper = torch.le(out_tensor, high)
                else:
                    below_upper = torch.lt(out_tensor, low + (step + 1) * step_size)
                hits = torch.sum(
                    torch.logical_and(above_lower, below_upper).long(), dim=0)
                sections[step] = sections[step].add(hits)

        # Computes Boundary Coverage
        if Modes.BOUNDARY in modes:
            low, high = bounds_on_device()
            counters = out.setdefault(Modes.BOUNDARY.name, {}).setdefault(name, {})
            if "strong" not in counters:
                counters["strong"] = torch.zeros(out_tensor.shape[1:]).to(device)
                counters["weak"] = torch.zeros(out_tensor.shape[1:]).to(device)
            # Count values that leave the recorded [low, high] region.
            counters["strong"] = counters["strong"].add(
                torch.sum(torch.gt(out_tensor, high), dim=0))
            counters["weak"] = counters["weak"].add(
                torch.sum(torch.lt(out_tensor, low), dim=0))

    return hook
def read_activations(activations_file):
    """Load pickled activation statistics and make them the module-level store.

    Args:
        activations_file: Path to a pickle file, or ``None`` to leave the
            current module-level ``activations`` untouched.

    Returns:
        The module-level ``activations`` object (the freshly loaded one when a
        path was given).
    """
    global activations
    if activations_file is not None:
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only pass files produced by a trusted run of this tooling.
        with open(activations_file, "rb") as fp:
            # NOTE(review): this rebinds the global. Hooks registered earlier
            # keep a reference to the previous object, so load before
            # registering hooks -- confirm against the calling script.
            activations = pickle.load(fp)
    return activations
def get_activations():
    """Return the object currently bound to the module-level ``activations``."""
    return activations
# Loop through all layers of the model and choose the ones with an activation function to register a hook
def get_all_activation_layers(net, modes, model_type, k=10, batch_size=4):
    """Register coverage hooks on the activation layers of ``net``.

    Every hook accumulates into the module-level ``activations`` object that is
    bound at the time of this call. The handles returned by
    ``register_forward_hook`` are appended to the module-level ``hooks`` list
    so the hooks can be removed later with ``handle.remove()``.

    Args:
        net: Model to instrument.
        modes: Container of ``Modes`` members forwarded to ``create_hook``.
        model_type: ``"vilt"`` or ``"albef"``; any other value registers nothing.
        k: Number of sections for k-section coverage.
        batch_size: Batch size forwarded to ``create_hook``.
    """

    def attach(layer_name, layer, act_fct):
        handle = layer.register_forward_hook(
            create_hook(layer_name, activations, modes, model_type, act_fct, k, batch_size))
        # Keep the handle; previously it was dropped, leaving no way to detach.
        hooks.append(handle)

    if model_type == "vilt":
        for child_name, child in net._modules.items():
            if child_name == "vilt":
                # Positional lookup: second submodule of the backbone's fourth
                # child -- presumably the pooler's Tanh; TODO confirm against
                # the model definition.
                backbone_children = list(child._modules.items())
                n3, l3 = list(backbone_children[3][1]._modules.items())[1]
                attach(n3, l3, Act_fct.TANH)
            else:
                # Positional lookup: third submodule of every other top-level
                # child -- presumably the classifier head's GELU; TODO confirm.
                n4, l4 = list(child._modules.items())[2]
                attach(n4, l4, Act_fct.GELU)
    elif model_type == "albef":
        mod = net.text_decoder.cls.predictions.transform.transform_act_fn
        attach("text_decoder", mod, Act_fct.GELU)
def compute_ksection_coverage(activations):
    """Compute k-section coverage per layer from the collected hit counters.

    For each layer key in ``activations["KSECTION"]`` the coverage is the mean,
    over its sections, of the fraction of counter elements that are greater
    than zero. Results are stored in ``activations["COVERAGE"]["KSECTION"]``
    and printed.

    Side effect: every section counter is replaced by its ``squeeze(0)``.

    Args:
        activations: Statistics dict containing a ``"KSECTION"`` entry.

    Returns:
        Tuple ``(activations, coverage_dict)`` where ``coverage_dict`` is a
        shallow copy of ``activations["COVERAGE"]``.

    Raises:
        KeyError: If ``activations`` has no ``"KSECTION"`` entry.
    """
    coverage = activations.setdefault("COVERAGE", {})
    ksection_coverage = coverage.setdefault("KSECTION", {})
    for key, sections in activations["KSECTION"].items():
        covered = 0
        for step in list(sections):
            sections[step] = sections[step].squeeze(0)
            counts = sections[step]
            # numel() rather than len(): len() is only the size of dim 0, which
            # over-reports coverage for multi-dimensional counters and fails
            # for 0-d ones.
            covered += torch.sum(counts > 0) / counts.numel()
        if sections:
            ksection_coverage[key] = (covered / len(sections)).item()
        else:
            # No sections recorded for this layer: nothing was covered.
            ksection_coverage[key] = 0.0
        print(f"KSECTION COVERAGE, '{key}'-Activation: {activations['COVERAGE']['KSECTION'][key]}")
    coverage_dict = coverage.copy()
    return activations, coverage_dict
def compute_boundary_strong_coverage(activations):
    """Compute boundary and strong coverage per layer from the BOUNDARY counters.

    For each layer key in ``activations["BOUNDARY"]``:

    * strong coverage = fraction of ``"strong"`` counter elements > 0;
    * boundary coverage = mean of the strong fraction and the corresponding
      ``"weak"`` fraction.

    Results are stored in ``activations["COVERAGE"]["STRONG"]`` and
    ``activations["COVERAGE"]["BOUNDARY"]`` and printed.

    Side effect: the ``"strong"`` and ``"weak"`` counters are replaced by their
    ``squeeze(0)``.

    Args:
        activations: Statistics dict containing a ``"BOUNDARY"`` entry.

    Returns:
        Tuple ``(activations, coverage_dict)`` where ``coverage_dict`` is a
        shallow copy of ``activations["COVERAGE"]``.

    Raises:
        KeyError: If ``activations`` has no ``"BOUNDARY"`` entry.
    """
    coverage = activations.setdefault("COVERAGE", {})
    boundary_coverage = coverage.setdefault("BOUNDARY", {})
    strong_coverage = coverage.setdefault("STRONG", {})
    for key, counters in activations["BOUNDARY"].items():
        counters["strong"] = counters["strong"].squeeze(0)
        counters["weak"] = counters["weak"].squeeze(0)
        # numel() rather than len(): len() is only the size of dim 0, which
        # over-reports coverage for multi-dimensional counters.
        strong = torch.sum(counters["strong"] > 0) / counters["strong"].numel()
        weak = torch.sum(counters["weak"] > 0) / counters["weak"].numel()
        boundary_coverage[key] = ((strong + weak) / 2).item()
        strong_coverage[key] = strong.item()
        print(f"BOUNDARY COVERAGE, '{key}'-Activation: {activations['COVERAGE']['BOUNDARY'][key]}")
        print(f"STRONG COVERAGE, '{key}'-Activation: {activations['COVERAGE']['STRONG'][key]}")
    coverage_dict = coverage.copy()
    return activations, coverage_dict