-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluate.py
29 lines (20 loc) · 855 Bytes
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import chainer
import fire
from inception import inception_v3
from inception import inception_resnet_v2
# Registry mapping the command-line ``model_type`` name to the loader that
# evaluate_inception calls with a checkpoint path to obtain the model.
_MODEL_LOADERS = dict(
    InceptionV3=inception_v3.load_inception_v3,
    InceptionResnetV2=inception_resnet_v2.load_inception_resnet_v2,
)
def evaluate_inception(model_type, checkpoint_path, dataset_path,
                       dataset_root='/', gpu=0, batch_size=100):
    """Evaluate a pretrained Inception model on a labeled image dataset.

    The loader registered under ``model_type`` in ``_MODEL_LOADERS`` is
    called with ``checkpoint_path``; its result is wrapped in
    ``chainer.links.Classifier`` and run once over the dataset with
    ``chainer.training.extensions.Evaluator``. The resulting observation
    dictionary is printed.

    Args:
        model_type: Name of a registered model (a key of ``_MODEL_LOADERS``).
        checkpoint_path: Passed unchanged to the selected loader.
        dataset_path: Image-path/label pairs in a form accepted by
            ``chainer.datasets.LabeledImageDataset``.
        dataset_root: Root directory that ``LabeledImageDataset`` resolves
            image paths against.
        gpu: GPU device id. A negative value keeps the model on the CPU.
        batch_size: Number of examples per evaluation batch (default 100).

    Raises:
        KeyError: If ``model_type`` is not a registered model name.
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if model_type not in _MODEL_LOADERS:
        # List the valid choices so a CLI typo is self-explanatory instead of
        # surfacing as a bare ``KeyError: 'foo'``.
        raise KeyError(
            'Unknown model_type {!r}; expected one of: {}'.format(
                model_type, ', '.join(sorted(_MODEL_LOADERS))))
    if batch_size < 1:
        # With a zero batch size, a non-repeating SerialIterator never
        # reaches the end of its epoch, so evaluation would never finish.
        raise ValueError(
            'batch_size must be a positive integer, got {!r}'.format(
                batch_size))

    loader = _MODEL_LOADERS[model_type]
    model = chainer.links.Classifier(loader(checkpoint_path))
    if gpu >= 0:
        model.to_gpu(gpu)

    # NOTE(review): images are fed exactly as LabeledImageDataset loads them,
    # with no resizing or normalization here. Confirm the loaded model
    # expects that input.
    dataset = chainer.datasets.LabeledImageDataset(
        dataset_path, root=dataset_root)
    # repeat=False makes a single pass over the data. shuffle=False keeps the
    # order deterministic.
    iterator = chainer.iterators.SerialIterator(
        dataset, batch_size, repeat=False, shuffle=False)
    # Use the same device id as the model, so batches go where the model lives.
    evaluator = chainer.training.extensions.Evaluator(
        iterator, model, device=gpu)
    result = evaluator()
    print(result)
if __name__ == '__main__':
    # Run as a script: Fire turns evaluate_inception's parameters into
    # positional arguments and --flags on the command line.
    fire.Fire(component=evaluate_inception)