-
Notifications
You must be signed in to change notification settings - Fork 0
/
eval.py
35 lines (27 loc) · 1.13 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import keras
import sys
import h5py
import numpy as np
from architecture import *
# Positional command-line arguments, read as soon as the module is loaded
# (so importing it with fewer than three arguments raises IndexError):
#   argv[1] -> clean test set, opened by data_loader via h5py
#   argv[2] -> poisoned test set, opened by data_loader via h5py
#   argv[3] -> saved model, opened with keras.models.load_model
(clean_data_filename,
 poisoned_data_filename,
 model_filename) = (str(sys.argv[idx]) for idx in (1, 2, 3))
def data_loader(filepath):
    """Load images and labels from an HDF5 file.

    Reads the ``'data'`` and ``'label'`` datasets from *filepath* into
    in-memory NumPy arrays. It then reorders the image axes with
    ``transpose((0, 2, 3, 1))``, moving axis 1 to the end. This presumably
    converts channels-first (N, C, H, W) images to the channels-last
    (N, H, W, C) layout the Keras model expects; confirm against the
    dataset producer.

    Args:
        filepath: Path to an HDF5 file containing ``'data'`` (4-D image
            array) and ``'label'`` datasets.

    Returns:
        Tuple ``(x_data, y_data)`` of NumPy arrays: the transposed images
        and the labels exactly as stored in the file.
    """
    # Use the file as a context manager so the HDF5 handle is always
    # released, even if reading a dataset fails. np.array copies the data
    # into memory, so the arrays stay valid after the file is closed.
    with h5py.File(filepath, 'r') as data:
        x_data = np.array(data['data'])
        y_data = np.array(data['label'])
    x_data = x_data.transpose((0, 2, 3, 1))
    return x_data, y_data
def _percent_correct(model, x, y):
    """Return the percentage of samples in *x* whose predicted class equals *y*.

    The predicted class is the argmax of ``model.predict(x)`` along axis 1.
    """
    predicted = np.argmax(model.predict(x), axis=1)
    # Flatten labels so an (N, 1) label array is compared element-wise
    # with the (N,) predictions instead of broadcasting to an (N, N) matrix.
    return np.mean(np.equal(predicted, np.ravel(y))) * 100


def main():
    """Evaluate the model on the clean and poisoned test sets and print results.

    Loads both datasets and the model from the paths given on the command
    line. It prints the clean classification accuracy and the attack success
    rate, both as percentages. The attack success rate is the share of
    poisoned samples whose prediction matches the poisoned file's labels;
    presumably these are the attacker's target labels, so verify this
    against the dataset.
    """
    cl_x_test, cl_y_test = data_loader(clean_data_filename)
    bd_x_test, bd_y_test = data_loader(poisoned_data_filename)

    # NOTE: a model that contains custom layers (e.g. ChannelPruningLayer or
    # CompareAndSelectLayer from `architecture`) must be loaded with
    #   keras.models.load_model(model_filename, custom_objects={
    #       "ChannelPruningLayer": ChannelPruningLayer,
    #       "CompareAndSelectLayer": CompareAndSelectLayer})
    bd_model = keras.models.load_model(model_filename)

    clean_accuracy = _percent_correct(bd_model, cl_x_test, cl_y_test)
    print('Clean Classification accuracy:', clean_accuracy)

    asr = _percent_correct(bd_model, bd_x_test, bd_y_test)
    print('Attack Success Rate:', asr)


if __name__ == '__main__':
    main()