-
Notifications
You must be signed in to change notification settings - Fork 0
/
ensemble_selection.py
165 lines (137 loc) · 6.18 KB
/
ensemble_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import pandas as pd
#import matplotlib.pyplot as plt
import time
import os
import copy
import argparse
from sklearn.model_selection import StratifiedKFold
import datetime
from PIL import Image
import torch.nn.functional as F
def cross_entropy(y, p):
    """Mean cross-entropy between integer labels and predicted probabilities.

    Parameters
    ----------
    y : np.ndarray of int, shape (N,)
        Ground-truth class indices.
    p : np.ndarray, shape (N, C)
        Non-negative class scores; each row is renormalized to sum to 1
        before the loss is computed.

    Returns
    -------
    0-d numpy array holding the scalar loss (lower is better).
    """
    # Normalize a *copy*: the original `p /= ...` mutated the caller's array
    # in place, rewriting rows of the shared probability matrix whenever the
    # metric was evaluated on a view (e.g. X_p[i] in ensembleSelection).
    p = p / p.sum(1).reshape(-1, 1)
    return F.nll_loss(torch.log(torch.tensor(p)), torch.tensor(y)).numpy()
def weighted_cross_entropy(y, p):
    """Class-weighted cross-entropy for the 3-class problem.

    Computes the mean NLL separately per class and sums the three per-class
    means with hard-coded weights [0.53, 0.3, 0.0] — class 2 is effectively
    ignored by this metric.

    Parameters
    ----------
    y : np.ndarray of int, shape (N,)
        Ground-truth class indices in {0, 1, 2}. NOTE(review): if some class
        has no samples, `nll_loss` is taken over an empty slice and the
        result is NaN — callers are assumed to pass data covering all
        classes; confirm against the training labels.
    p : np.ndarray, shape (N, 3)
        Non-negative class scores; rows are renormalized to sum to 1.

    Returns
    -------
    Scalar numpy value of the weighted loss (lower is better).
    """
    # Normalize a *copy* so the caller's array is not mutated in place
    # (the original `p /= ...` modified the input array of the caller).
    p = p / p.sum(1).reshape(-1, 1)
    w_arr = np.array([0.53, 0.3, 0.0])
    return np.sum([
        F.nll_loss(torch.log(torch.tensor(p[y == c])), torch.tensor(y[y == c])).numpy() * w_arr[c]
        for c in range(3)
    ])
class ensembleSelection:
    """Greedy ensemble selection over a library of models (Caruana-style),
    combining per-model class probabilities with a geometric mean.

    Parameters
    ----------
    metric : callable(y, p) -> score
        Loss evaluated on row-stochastic probability matrices; LOWER is
        better (see `_compare`).
    """

    def __init__(self, metric):
        self.metric = metric

    def _compare(self, sc1, sc2):
        # The metric is a loss: sc1 beats sc2 iff it is strictly smaller.
        return sc1 < sc2

    def _initialize(self, X_p, y):
        """Return (index, score) of the single best model in the library.

        X_p : (N, D, C) array — N models' probabilities for D samples.
        y   : (D,) ground-truth labels.
        """
        best_ind = 0
        best_sc = self.metric(y, X_p[0])
        for i in range(1, X_p.shape[0]):
            sc = self.metric(y, X_p[i])
            if self._compare(sc, best_sc):
                best_sc, best_ind = sc, i
        return best_ind, best_sc

    def es_with_replacement(self, X_p, Xtest_p, y):
        """Greedily grow an ensemble, allowing the same model to be picked
        repeatedly, until no single addition improves the metric.

        Returns (best score, combined train probs, combined test probs);
        the combined matrices are row-normalized geometric means of the
        selected models' probabilities.
        """
        best_ind, best_sc = self._initialize(X_p, y)
        current_sc = best_sc
        # Running *products* of the selected models' probabilities; the
        # geometric mean is taken once at the end.
        sumP = np.copy(X_p[best_ind])
        sumP_test = np.copy(Xtest_p[best_ind])
        i = 1
        while True:
            i += 1
            ind = -1
            for m in range(X_p.shape[0]):
                # Would adding model m (i factors total) lower the loss?
                sc = self.metric(y, (sumP * X_p[m]) ** (1 / i))
                if self._compare(sc, current_sc):
                    current_sc = sc
                    ind = m
            if ind > -1:
                sumP *= X_p[ind]
                sumP_test *= Xtest_p[ind]
            else:
                break
        # The loop exits after a failed round, so exactly i-1 factors were
        # multiplied in: take their geometric mean and renormalize rows.
        sumP = sumP ** (1 / (i - 1))
        sumP_test = sumP_test ** (1 / (i - 1))
        sumP /= sumP.sum(1).reshape(-1, 1)
        sumP_test /= sumP_test.sum(1).reshape(-1, 1)
        return current_sc, sumP, sumP_test

    def es_with_bagging(self, X_p, Xtest_p, y, f = 0.5, n_bags = 20):
        """Run `es_with_replacement` on `n_bags` random subsets of the
        library (each of size f * N) and geometric-average the bag results.

        Returns (final score, combined train probs, combined test probs).
        """
        list_of_indices = [i for i in range(X_p.shape[0])]
        bag_size = int(f * X_p.shape[0])
        sumP = None
        sumP_test = None
        for i in range(n_bags):
            # Draw a random bag of models. (The original code also built an
            # unused `model_weight` list here every iteration; removed.)
            rng = np.copy(list_of_indices)
            np.random.shuffle(rng)
            rng = rng[:bag_size]
            # Best combination found within this bag.
            sc, p, ptest = self.es_with_replacement(X_p[rng], Xtest_p[rng], y)
            print('bag: %d, sc: %f'%(i, sc))
            if sumP is None:
                sumP = p
                sumP_test = ptest
            else:
                sumP *= p
                sumP_test *= ptest
        # Geometric mean across bags, then row-normalize.
        sumP = sumP ** (1 / n_bags)
        sumP_test = sumP_test ** (1 / n_bags)
        sumP /= sumP.sum(1).reshape(-1, 1)
        sumP_test /= sumP_test.sum(1).reshape(-1, 1)
        # Floor tiny probabilities so a downstream log() cannot hit -inf.
        sumP[sumP < 1e-6] = 1e-6
        sumP_test[sumP_test < 1e-6] = 1e-6
        final_sc = self.metric(y, sumP)
        print('avg sc: %f'%(final_sc))
        return (final_sc, sumP, sumP_test)
# Command-line interface. All arguments have defaults, so the script can run
# from the expected working directory with no flags.
parser = argparse.ArgumentParser(description='Data preperation')
parser.add_argument('--train_data_path', help='path to training data folder', default='train_data', type=str)
parser.add_argument('--data_path', help='path to training and test numpy matrices of images', default='.', type=str)
parser.add_argument('--sample_sub_file_path', help='path to sample submission file', default='.', type=str)
# Number of per-model prediction files expected under --library_path.
parser.add_argument('--library_size', help='number of models to be trained in the library of models', default=50, type=int)
parser.add_argument('--library_path', help='save path for validation and test predictions of the library of models', default='trails', type=str)
parser.add_argument('--final_sub_file_save_path', help='save path for final submission file', default='.', type=str)
args = parser.parse_args()
# Fix the RNG so the bagging inside ensembleSelection is reproducible.
np.random.seed(4321)
n = args.library_size
# Ground-truth class labels for the training samples.
train_gts = np.load(os.path.join(args.data_path, 'unique_train_gts_rot_fixed.npy'))
# Validation probabilities of each library model on the training data,
# stacked into an (N, D, 3) matrix: N models, D samples, 3 classes.
# ("trail" in the filenames is the spelling used by the trial-generation
# step and must match what was written to disk.)
train_prob = np.array([np.load(os.path.join(args.library_path, 'val_prob_trail_%d.npy'%(i))) for i in range(n)])
# Test-set probabilities of each library model, same (N, D, 3) layout.
test_prob = np.array([np.load(os.path.join(args.library_path, 'test_prob_trail_%d.npy'%(i))) for i in range(n)])
# Sample IDs for the submission; loaded from the working directory.
ids = np.load('ids.npy').tolist()
# Pass 1: ensemble selection under plain cross-entropy (geometric averaging).
es_obj = ensembleSelection(cross_entropy)
sc, es_train_prob, es_test_prob = es_obj.es_with_bagging(train_prob, test_prob, train_gts, n_bags = 10, f = 0.65)
# Remember test samples this ensemble confidently assigns to class 2
# (healthy wheat, per the comments) before re-ensembling.
idx = (np.max(es_test_prob, 1) > 0.7) & (np.argmax(es_test_prob, 1) == 2)
# Pass 2: re-ensemble with class weights favouring the leaf/stem classes
# (weighted_cross_entropy gives class 2 zero weight).
es_obj = ensembleSelection(weighted_cross_entropy)
sc, es_train_prob, es_test_prob = es_obj.es_with_bagging(train_prob, test_prob, train_gts, n_bags = 10, f = 0.65)
# For the samples pass 1 was confident about, override the weighted ensemble
# and force essentially all probability mass onto class 2.
es_test_prob[idx, 0] = 1e-6
es_test_prob[idx, 1] = 1e-6
es_test_prob[idx, 2] = 1.0
# Build the submission file.
sub = pd.read_csv(os.path.join(args.sample_sub_file_path, 'sample_submission.csv'))
sub['ID'] = ids
# NOTE(review): column order comes from os.listdir(), which is not
# guaranteed to be sorted — this assumes the listing order matches the
# class-index order used during training; confirm on the target system.
lbl_names = os.listdir(args.train_data_path)
for i, name in enumerate(lbl_names):
    sub[name] = es_test_prob[:,i].tolist()
sub.to_csv(os.path.join(args.final_sub_file_save_path, 'final_sub.csv'), index = False)