example_adult_mcar.py
import numpy as np
import seaborn as sbn
from collections import defaultdict
from scipy.stats import mode, itemfreq
from scipy import delete
from sklearn.metrics import confusion_matrix
import matplotlib as mpl
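# use a non-interactive backend so figures can be rendered and saved without a display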
mpl.use('Agg')
import matplotlib.pyplot as plt
from missing_data_imputation import Imputer
plt.rcParams.update({'figure.autolayout': True})
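# NOTE: this script targets Python 2 (print statements, xrange) and an older
# SciPy that still re-exports numpy's delete and provides scipy.stats.itemfreq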
# load the raw ADULT training data (fields separated by ', ')
x = np.genfromtxt('adult-train-raw', delimiter=', ', dtype=object)
# drop the redundant education-number feature and the income label column (columns 4 and 14)
x = delete(x, (4, 14), 1)
# enumerate parameters and instantiate Imputer
imp = Imputer()
missing_data_symbol = '?'
miss_data_cond = lambda x: x == missing_data_symbol
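# categorical column indices after dropping columns 4 and 14 above (assuming
# the standard ADULT column order: workclass, education, marital-status,
# occupation, relationship, race, sex, native-country)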
cat_cols = (1, 3, 4, 5, 6, 7, 8, 12)
n_neighbors = 5
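# locate entries that are already missing ('?') in the raw data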
miss_data_rows, miss_data_cols = np.where(miss_data_cond(x))
# remove rows that already contain missing data, which are MNAR in the ADULT dataset
x = np.delete(x, miss_data_rows, axis=0)
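# missing-data percentages to simulate (10, 20, ..., 90)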
ratios = np.arange(10, 100, 10)
# TODO: the monotone=False code path still needs to be fixed; keep monotone=True for now
monotone = True
def perturbate_data(x, cat_cols, ratio, missing_data_symbol,
monotone=False, in_place=False):
"""Perturbs data by substituting existing values with missing data symbol
such that each feature has a minimum missing data ratio
"""
def zero():
return 0
if in_place:
data = x
else:
data = np.copy(x)
n_perturbations = int(len(x) * ratio)
if monotone:
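        # monotone case: restrict perturbation to roughly 30% of the categorical
        # columns (sampled with replacement) and draw (row, col) positions at random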
cat_cols = np.random.choice(cat_cols, int(len(cat_cols) * .3))
rows = np.random.randint(0, len(data), n_perturbations)
cols = np.random.choice(cat_cols, n_perturbations)
data[rows, cols] = missing_data_symbol
miss_dict = defaultdict(list)
for (row, col) in np.dstack((rows, cols))[0]:
miss_dict[col].append(row)
else:
        # slow version: perturb one entry at a time so that no cell is perturbed
        # twice and no row loses more than half of its categorical features
row_col_miss = defaultdict(zero)
miss_dict = defaultdict(list)
i = 0
while i < n_perturbations:
row = np.random.randint(0, len(data))
col = np.random.choice(cat_cols)
            # perturb only if fewer than half of this row's categorical features
            # are missing and the cell is not already perturbed
if row_col_miss[row] < len(cat_cols) * 0.5 and data[row, col] != missing_data_symbol:
data[row, col] = missing_data_symbol
row_col_miss[row] += 1
miss_dict[col].append(row)
i += 1
return data, miss_dict
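# Hypothetical usage (not executed here), assuming the arguments defined above:
#   pert, ids = perturbate_data(x, cat_cols, 0.3, missing_data_symbol, monotone=True)
# would mark roughly 0.3 * len(x) categorical entries as missing and return the
# perturbed copy together with a dict mapping each column to its perturbed rows.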
def compute_histogram(data, labels):
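    """Return a {label: count} histogram of data, padding with zero counts
    for labels that never occur."""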
histogram = dict(itemfreq(data))
for label in labels:
if label not in histogram:
histogram[label] = .0
return histogram
def compute_error_rate(y, y_hat, feat_imp_ids):
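    """Return the per-column imputation error rate: the fraction of perturbed
    entries (rows listed in feat_imp_ids) whose imputed value differs from y."""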
error_rate = {}
for col, ids in feat_imp_ids.items():
errors = sum(y[ids, col] != y_hat[ids, col])
error_rate[col] = errors / float(len(ids))
return error_rate
# helper function to plot histograms
def plot_histogram(freq_data, labels, axes, axis, width, title,
color_mapping):
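    """Plot grouped per-method bars for one feature's value histogram.

    Note: the labels argument is effectively ignored; the labels are re-derived
    below from freq_data so every method shares the same sorted label order.
    """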
n_methods = len(freq_data.keys())
labels = sorted(freq_data.values()[0].keys())
bins = np.arange(len(labels))
for i in xrange(n_methods):
key = sorted(freq_data.keys())[i]
offset = i*2*width/float(n_methods)
values = [freq_data[key][label] for label in labels]
axes.flat[axis].bar(bins+offset, values,
width, label=key,
color=plt.cm.Set1(color_mapping[key]),
align='center')
axes.flat[axis].set_xlim(bins[0]-0.5, bins[-1]+width+0.5)
axes.flat[axis].set_title(title)
axes.flat[axis].set_xticks(bins + width)
axes.flat[axis].set_xticklabels(labels, rotation=90,
fontsize='small')
axes.flat[axis].legend(loc='best', prop={'size': 8},
shadow=True, fancybox=True)
def plot_confusion_matrix(y, y_predict, axes, axis, title='',
normalize=True, add_text=False):
"""Plots a confusion matrix given labels and predicted labels
Parameters
----------
y: ground truth labels <int array>
y_predict: predicted labels <int array>
"""
    conf_mat = confusion_matrix(y, y_predict)
    if normalize:
        conf_mat_norm = conf_mat / conf_mat.sum(axis=1).astype(float)[:, np.newaxis]
        conf_mat_norm = np.nan_to_num(conf_mat_norm)
    else:
        # fall back to raw counts so the function also works with normalize=False
        conf_mat_norm = conf_mat
    axes.flat[axis].imshow(conf_mat_norm, cmap=plt.cm.Blues,
                           interpolation='nearest')
if axis < axes.shape[1]:
axes.flat[axis].set_title(title)
# add text to confusion matrix
if add_text:
for x in xrange(conf_mat.shape[0]):
for y in xrange(conf_mat.shape[0]):
if conf_mat[x, y] > 0:
axes.flat[axis].annotate(str(conf_mat[x, y]),
xy=(y, x),
horizontalalignment='center',
verticalalignment='center')
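# Main experiment loop: for each missing-data ratio, perturb the complete data,
# impute with every method, then compare imputed values against the originals
# via confusion matrices, per-column value histograms and per-column error rates.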
for ratio in ratios:
print 'Experiments on {}% missing data'.format(ratio)
pert_data, feat_imp_ids = perturbate_data(x, cat_cols, .01*ratio,
missing_data_symbol,
monotone=monotone)
miss_data_cols = feat_imp_ids.keys()
print 'Missing data cols {}'.format(miss_data_cols)
data_dict = {}
data_dict['RawData'] = pert_data
# drop observations with missing variables
print 'imputing with drop'
data_dict['Drop'] = imp.drop(pert_data, miss_data_cond)
# replace missing values with random existing values
print 'imputing with random replacement'
data_dict['RandomReplace'] = imp.replace(pert_data, miss_data_cond)
# replace missing values with feature summary
print 'imputing with feature summarization (mode)'
summ_func = lambda x: mode(x)[0]
data_dict['Mode'] = imp.summarize(pert_data, summ_func, miss_data_cond)
# replace missing data with predictions using random forest
print 'imputing with Random Forest'
data_dict['RandomForest'] = imp.predict(pert_data, cat_cols, miss_data_cond)
# replace missing data with values obtained after factor analysis
print 'imputing with PCA'
data_dict['PCA'] = imp.factor_analysis(pert_data, cat_cols, miss_data_cond)
# replace missing data with knn
print 'imputing with K-Nearest Neighbors'
data_dict['KNN'] = imp.knn(pert_data, n_neighbors, np.mean, miss_data_cond,
cat_cols)
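    # methods evaluated against the ground truth below; 'RawData' and 'Drop'
    # are excluded since they leave no imputed values to score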
conf_methods = ['RandomReplace', 'Mode', 'RandomForest', 'PCA', 'KNN']
methods = ['RawData', 'Drop', 'RandomReplace', 'Mode', 'RandomForest',
'PCA', 'KNN']
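    # give each method a fixed position in the Set1 colormap so its color is
    # consistent across all figures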
color_mapping = {}
for i in xrange(len(methods)):
color_mapping[methods[i]] = (i+1) / float(len(methods))
###########################
# plot confusion matrices #
###########################
fig, axes = plt.subplots(len(miss_data_cols), len(conf_methods),
figsize=(8, 8))
axis = 0
for col in miss_data_cols:
for key in conf_methods:
plot_confusion_matrix(x[:, col], data_dict[key][:, col], axes, axis,
key)
axis += 1
plt.savefig('images/mcar_mono_{}_conf_matrix_miss_ratio_{}.png'.format(monotone, ratio), dpi=300)
#######################
# compute error rates #
#######################
error_rates = {}
for method in conf_methods:
error_rates[method] = compute_error_rate(x, data_dict[method],
feat_imp_ids)
# set plot params
fig, axes = plt.subplots(1 + (len(miss_data_cols)/3), 3, figsize=(16, 9))
width = .25
###############################
# compute and plot histograms #
###############################
for i in xrange(len(miss_data_cols)):
col = miss_data_cols[i]
labels = np.unique(x[:, col])
freq_data = {}
for key, data in data_dict.items():
freq_data[key] = compute_histogram(data[:, col], labels)
plot_histogram(freq_data, labels, axes, i, width, col, color_mapping)
plt.savefig('images/mcar_mono_{}_dist_ratio_{}.png'.format(monotone, ratio), dpi=300)
########################
# plot error rate bars #
########################
fig, axes = plt.subplots(1, 1, figsize=(10, 5))
n_methods = len(error_rates.keys())
bins = np.arange(len(feat_imp_ids))
width = .25
for i in xrange(n_methods):
key = sorted(error_rates.keys())[i]
offset = i*width/float(n_methods)
values = [error_rates[key][feat] for feat in sorted(error_rates[key])]
axes.bar(bins+offset, values, width, label=key,
color=plt.cm.Set1(color_mapping[key]),
align='center')
axes.set_xlim(bins[0]-0.5, bins[-1]+width+0.5)
axes.set_xticks(bins + width)
axes.set_xticklabels(sorted(feat_imp_ids.keys()))
axes.legend(loc='best', prop={'size': 8},
shadow=True, fancybox=True)
axes.set_title('Error rates')
plt.savefig('images/mcar_mono_{}_error_miss_ratio_{}.png'.format(monotone, ratio), dpi=300)