forked from BeenKim/MMD-critic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmmd.py
150 lines (121 loc) · 5.38 KB
/
mmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# maintained by [email protected]
import numpy as np
# from mpi4py import MPI
import sys
import math
##############################################################################################################################
# function to select criticisms
# ARGS:
# K: Kernel matrix
# selectedprotos: prototypes already selected
# m : number of criticisms to be selected
# reg: regularizer type; one of 'None', 'logdet' or 'iterative'.
# is_K_sparse: True if K is a pre-computed CSC sparse matrix; False if it is a dense matrix.
# RETURNS: indices selected as criticisms
##############################################################################################################################
def select_criticism_regularized(K, selectedprotos, m, reg='logdet', is_K_sparse=True):
    """Greedily pick ``m`` criticism indices from the kernel matrix ``K``.

    Candidates are all row/column indices of ``K`` that are neither in
    ``selectedprotos`` nor already picked.  At every step each candidate
    ``c`` is scored by

        abs(mean_i K[i, c] - mean_{p in selectedprotos} K[p, c])

    optionally modified by the regularizer, and the best-scoring candidate
    is appended to the result.

    Regularizers (``reg``):
      * ``'None'``      -- the score above is used unchanged.
      * ``'logdet'``    -- once at least one criticism ``S`` has been picked,
        ``log|K[c, c] - k_c^T inv(K[S, S]) k_c|`` (with ``k_c = K[S, c]``) is
        added to the score; on the very first pick ``log|K[c, c]|`` is
        subtracted instead.
      * ``'iterative'`` -- every picked criticism is appended to the
        prototype set that is used for the prototype mean in later steps.

    Args:
        K: square kernel matrix, either a dense NumPy array or a SciPy
            sparse matrix supporting fancy row/column indexing.
        selectedprotos: indices of the prototypes already selected.
        m: number of criticisms to select.
        reg: regularizer name, one of ``'None'``, ``'logdet'``,
            ``'iterative'``.
        is_K_sparse: kept for backward compatibility.  The storage format is
            detected from ``K`` itself, so a mismatching flag is harmless.

    Returns:
        1-D integer NumPy array with the selected indices, in pick order.

    Raises:
        ValueError: if ``reg`` is not a known regularizer, or if ``m``
            exceeds the number of non-prototype candidates.
        numpy.linalg.LinAlgError: with ``reg='logdet'`` when the kernel
            sub-matrix of the picked criticisms is singular.
    """
    if reg not in ('None', 'logdet', 'iterative'):
        raise ValueError("wrong regularizer :" + str(reg))

    def _as_dense(block):
        # Blocks passed here are at most (#picked x #candidates), so turning
        # a sparse slice into a plain ndarray is cheap and avoids np.matrix.
        if hasattr(block, 'toarray'):
            return block.toarray()
        return np.asarray(block)

    n = np.shape(K)[0]
    # dtype=int keeps an empty prototype list usable as a fancy index.
    selectedprotos = np.asarray(selectedprotos, dtype=int)
    non_protos = np.setdiff1d(np.arange(n), selectedprotos)
    if m > len(non_protos):
        raise ValueError(
            "cannot select %d criticisms from %d non-prototype candidates"
            % (m, len(non_protos)))

    # float(n) forces true division even for integer kernels on Python 2.
    colsum = np.asarray(K.sum(axis=0)).ravel() / float(n)
    diag = np.asarray(K.diagonal()).ravel()

    selected = np.array([], dtype=int)
    inverse_of_prev_selected = None  # inv(K[selected, selected]) for 'logdet'

    for _ in range(m):
        candidates = np.setdiff1d(non_protos, selected)

        # Mean kernel value between each candidate and the prototypes.
        proto_block = _as_dense(K[selectedprotos, :][:, candidates])
        proto_mean = proto_block.sum(axis=0) / float(len(selectedprotos))
        scores = np.abs(colsum[candidates] - proto_mean)

        if reg == 'logdet':
            if inverse_of_prev_selected is None:
                # NOTE(review): the first pick subtracts the log term while
                # later picks add it; kept as in the original -- confirm the
                # intended sign (no effect when diag(K) == 1).
                scores = scores - np.log(np.abs(diag[candidates]))
            else:
                cross = _as_dense(K[selected, :][:, candidates])
                # Column-wise quadratic form k_c^T inv(K_SS) k_c computed via
                # a Hadamard product.
                quad = np.sum(
                    np.asarray(np.dot(inverse_of_prev_selected, cross)) * cross,
                    axis=0)
                # The log is applied for sparse and dense input alike so that
                # both storage formats select the same criticisms.
                scores = scores + np.log(np.abs(diag[candidates] - quad))

        best = candidates[np.argmax(scores)]
        selected = np.append(selected, best)

        if reg == 'logdet':
            picked_kernel = _as_dense(K[selected, :][:, selected])
            inverse_of_prev_selected = np.linalg.inv(picked_kernel)
        elif reg == 'iterative':
            selectedprotos = np.append(selectedprotos, best)

    return selected
##############################################################################################################################
# Function to greedily choose m of the candidate rows as prototypes, scoring candidates with the MMD objective on kernel matrix K
# ARGS:
# K : kernel matrix
# candidate_indices : array of potential choices for selections, returned values are chosen from these indices
# m: number of selections to be made
# is_K_sparse: True if K is a pre-computed CSC sparse matrix; False if it is a dense matrix.
# RETURNS: subset of candidate_indices which are selected as prototypes
##############################################################################################################################
def greedy_select_protos(K, candidate_indices, m, is_K_sparse=False):
    """Greedily choose ``m`` prototypes among ``candidate_indices``.

    ``K`` is first restricted to the candidate rows/columns.  With ``n``
    candidates and ``S`` the set picked so far, each remaining candidate
    ``c`` is scored by

        2/n * sum_i K[i, c] - (2 * sum_{s in S} K[s, c] + K[c, c]) / (|S| + 1)

    (for the first pick the second term is ``abs(K[c, c])``), and the
    highest-scoring candidate is added to ``S``.

    Args:
        K: square kernel matrix, either a dense NumPy array or a SciPy
            sparse matrix supporting fancy row/column indexing.
        candidate_indices: indices into ``K`` that may be chosen; the
            returned values are drawn from this sequence.
        m: number of prototypes to select.
        is_K_sparse: kept for backward compatibility.  The storage format is
            detected from ``K`` itself, so a mismatching flag is harmless.

    Returns:
        1-D integer NumPy array holding the chosen entries of
        ``candidate_indices``, in pick order.

    Raises:
        ValueError: if ``m`` exceeds the number of candidates.
    """
    candidate_indices = np.asarray(candidate_indices, dtype=int)
    n = len(candidate_indices)
    if m > n:
        raise ValueError(
            "cannot select %d prototypes from %d candidates" % (m, n))

    # Work in candidate-local coordinates.  Re-indexing is skipped only when
    # the candidates are exactly 0..n-1 in order; a same-length permutation
    # must still be applied, otherwise the final mapping back is wrong.
    if n != np.shape(K)[0] or not np.array_equal(candidate_indices, np.arange(n)):
        K = K[:, candidate_indices][candidate_indices, :]

    # np.asarray(...).ravel() flattens the np.matrix results that sparse
    # matrices return; float(n) forces true division on Python 2.
    colsum = 2 * np.asarray(K.sum(axis=0)).ravel() / float(n)
    diag = np.asarray(K.diagonal()).ravel()

    all_local = np.arange(n)
    selected = np.array([], dtype=int)
    for _ in range(m):
        candidates = np.setdiff1d(all_local, selected)
        gain = colsum[candidates]

        if len(selected) > 0:
            cross = K[selected, :][:, candidates]
            cross_sum = np.asarray(cross.sum(axis=0)).ravel()
            penalty = (2 * cross_sum + diag[candidates]) / (len(selected) + 1.0)
            gain = gain - penalty
        else:
            gain = gain - np.abs(diag[candidates])

        selected = np.append(selected, candidates[np.argmax(gain)])

    # Map candidate-local positions back to the caller's indices.
    return candidate_indices[selected]