-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlog_functions.py
97 lines (78 loc) · 3.19 KB
/
log_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
def log_features_dlgn(model, bias_log=False):
    '''
    Collect the composed gate weight matrices of the model for logging.

    Every parameter named ``gates.<i>.weight`` with ``i`` in
    ``range(model.depth)`` is gathered in ``model.named_parameters()`` order.
    The first weight is used as-is; each later weight is left-multiplied onto
    the previous composite, so block ``k`` of the result is
    ``W_k @ W_{k-1} @ ... @ W_0``.

    Args:
        model: object exposing ``named_parameters()`` and an integer ``depth``.
        bias_log: accepted for backward compatibility. Gate biases are not
            part of the returned tensor, so this flag does not change the
            result.

    Returns:
        The composed matrices concatenated along dim 0. Presumably of shape
        (num_gates, data_dim) when every gate is a linear layer -- TODO
        confirm against the model definition.

    Raises:
        IndexError: if no parameter name matches ``gates.<i>.weight``
            (the first collected weight is indexed unconditionally).
    '''
    # A set gives O(1) name checks instead of re-scanning every depth index
    # for every parameter; iteration order of named_parameters() is kept.
    wanted = {f'gates.{i}.weight' for i in range(model.depth)}
    weights = [
        param.data
        for name, param in model.named_parameters()
        if name in wanted
    ]

    composed = [weights[0]]
    for w in weights[1:]:
        # Chain onto the previous composite so each block maps from the input space.
        composed.append(w @ composed[-1])

    # Call .to("cpu") on the result if you want to use it with numpy.
    return torch.cat(composed, dim=0)
def log_features_dlgn_sf(model, bias_log=False):
    '''
    Stack the raw gate weight matrices of the model for logging.

    Every parameter named ``gates.<i>.weight`` with ``i`` in
    ``range(model.depth)`` is gathered in ``model.named_parameters()`` order
    and concatenated as-is; the weights are NOT multiplied together.

    Args:
        model: object exposing ``named_parameters()`` and an integer ``depth``.
        bias_log: accepted for backward compatibility. Gate biases are not
            part of the returned tensor, so this flag does not change the
            result.

    Returns:
        The gate weights concatenated along dim 0. Presumably of shape
        (num_gates, data_dim) when every gate reads the input directly --
        TODO confirm against the model definition.

    Note:
        The collected list is handed straight to ``torch.cat``, so a model
        with no matching parameter fails inside ``torch.cat``.
    '''
    # A set gives O(1) name checks instead of re-scanning every depth index
    # for every parameter; iteration order of named_parameters() is kept.
    wanted = {f'gates.{i}.weight' for i in range(model.depth)}
    weights = [
        param.data
        for name, param in model.named_parameters()
        if name in wanted
    ]

    # Call .to("cpu") on the result if you want to use it with numpy.
    return torch.cat(weights, dim=0)
def log_features_DLGN_kernel(model, bias=False):  ## Don't use it for DLGN model
    '''
    Collect the composed, transposed gate matrices of a kernel-style model.

    Parameters named exactly ``gates.<i>`` (no ``.weight`` suffix) with ``i``
    in ``range(model.depth)`` are gathered in ``model.named_parameters()``
    order. The first one is transposed; each later one is transposed and
    left-multiplied onto the previous composite, so block ``k`` of the result
    is ``W_k.T @ W_{k-1}.T @ ... @ W_0.T``.

    Args:
        model: object exposing ``named_parameters()`` and an integer ``depth``.
            Presumably its gates are bare parameters (e.g. a ParameterList)
            rather than layers -- TODO confirm against the model definition.
        bias: accepted for backward compatibility. Biases are not part of the
            returned tensor, so this flag does not change the result. (It was
            previously shadowed by a local list and therefore never honoured.)

    Returns:
        The composed matrices concatenated along dim 0.

    Raises:
        IndexError: if no parameter name matches ``gates.<i>``
            (the first collected matrix is indexed unconditionally).
    '''
    # A set gives O(1) name checks instead of re-scanning every depth index
    # for every parameter; iteration order of named_parameters() is kept.
    wanted = {f'gates.{i}' for i in range(model.depth)}
    weights = [
        param.data
        for name, param in model.named_parameters()
        if name in wanted
    ]

    composed = [weights[0].T]
    for w in weights[1:]:
        # Chain onto the previous composite so each block maps from the input space.
        composed.append(w.T @ composed[-1])

    # Call .to("cpu") on the result if you want to use it with numpy.
    return torch.cat(composed, dim=0)
def feature_stats(features, data_dim=18, tree_depth=4, dim_in=18, threshold=0.1, req_index=False):
    '''
    Count the features that are aligned with the leading standard basis vectors.

    Each feature is L2-normalised and compared with the first
    ``2**tree_depth - 1`` rows of the ``data_dim`` x ``data_dim`` identity.
    A feature counts for node ``i`` when its unit vector lies within
    ``threshold`` (Euclidean distance) of ``e_i`` or of ``-e_i``.

    Args:
        features: 2-d tensor (or iterable of 1-d tensors) iterated row by row.
        data_dim: size of the standard basis, i.e. expected feature length.
        tree_depth: the number of nodes checked is ``2**tree_depth - 1``, so
            ``tree_depth=1`` checks only the first basis vector and
            ``tree_depth=0`` checks nothing (empty results).
        dim_in: unused; kept so existing callers keep working.
        threshold: maximum distance for a feature to count as a match.
        req_index: when True, also return the matching feature indices.

    Returns:
        ``count``, a 1-d tensor of length ``2**tree_depth - 1``; or
        ``(count, index)`` when ``req_index`` is True, where ``index[i]`` is
        the list of feature positions that matched node ``i``.

    Raises:
        IndexError: if ``2**tree_depth - 1`` exceeds ``data_dim`` and
            ``features`` is non-empty (a basis row past the end is indexed).
    '''
    num_nodes = 2**tree_depth - 1
    basis = torch.eye(data_dim)  # standard basis, one vector per row
    if isinstance(features, torch.Tensor):
        # Keep the comparison on the same device as the features.
        basis = basis.to(device=features.device)

    count = torch.zeros(num_nodes)
    # One independent list per node; `[[]] * n` would alias a single list,
    # making every node report the same indices.
    index = [[] for _ in range(num_nodes)]

    for ind, item in enumerate(features):
        # Normalise once per feature instead of once per comparison.
        unit = item / item.norm(dim=0, p=2)
        for i in range(num_nodes):
            target = basis[i]
            # Sign-agnostic match: accept alignment with e_i or with -e_i.
            if (torch.linalg.vector_norm(unit - target) < threshold
                    or torch.linalg.vector_norm(unit + target) < threshold):
                count[i] += 1
                if req_index:
                    index[i].append(ind)

    if req_index:
        return count, index
    return count