-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_inception_classifier.py
122 lines (78 loc) · 2.82 KB
/
test_inception_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# --------------------------------------------------------
# Tensorflow Faster R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Jiasen Lu, Jianwei Yang, based on code from Ross Girshick
# --------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import argparse
import torch
from my_inceptionV3_classifier import inceptionV3
from pathlib import Path
import pandas as pd
from cub_dataloader import CUB
#os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Python 2/3 compatibility: make ``xrange`` resolvable on both interpreters.
try:
    xrange = xrange  # Python 2: rebind the builtin lazy range at module level
except NameError:
    xrange = range  # Python 3: the builtin ``range`` already plays this role
def parse_args(argv=None):
    """Parse command-line arguments for evaluating the InceptionV3 classifier.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as the original did.

    Returns:
        argparse.Namespace with:
            cuda (bool): True when ``--cuda`` was passed.
            load (str): path of the checkpoint to evaluate.
    """
    # The description previously read "Train a Fast R-CNN network", which was
    # copied from another script and misdescribed what this one does.
    parser = argparse.ArgumentParser(
        description='Evaluate an InceptionV3 classifier on the CUB test split')
    parser.add_argument('--cuda', dest='cuda',
                        help='whether use CUDA',
                        action='store_true')
    parser.add_argument('--load', dest='load',
                        help='model to be used for testing',
                        default='/export/work/m.bharti/output/cub_inception_30.pth')
    return parser.parse_args(argv)
def evaluate(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and count correct top-1 predictions.

    Args:
        model: torch.nn.Module called as ``model(x, y)``; the second element of
            the returned tuple is used as the class scores (presumably shaped
            (batch, num_classes) -- TODO confirm against inceptionV3).
        dataloader: iterable yielding ``(x, y)`` tensor batches.
        device: torch.device the batches are moved to; must match the model.

    Returns:
        (correct, total) as plain ints: number of correct top-1 predictions
        and number of images seen.
    """
    model.eval()
    correct = 0
    total = 0
    # Inference only: no_grad stops autograd from building a graph per batch,
    # which otherwise wastes (GPU) memory for the whole evaluation.
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device).float()
            y = y.to(device).long()
            _, y_pred = model(x, y)
            _, pred = torch.max(y_pred, 1)
            correct += int((pred == y).sum().item())
            total += y.shape[0]
            print('imgs:', total)
    return correct, total


if __name__ == '__main__':
    args = parse_args()
    print('Called with args:')
    print(args)
    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    if args.cuda and not torch.cuda.is_available():
        raise SystemExit("--cuda was given but no CUDA device is available")
    # Honour --cuda: everything used to be moved to the GPU unconditionally,
    # which made the flag meaningless and crashed on CPU-only machines.
    device = torch.device('cuda' if args.cuda else 'cpu')

    total_ids = 200  # number of classifier outputs (was hard-coded in inceptionV3(200))
    load_name = args.load

    # initialize the network here.
    inception = inceptionV3(total_ids)
    print("load checkpoint %s" % (load_name))
    # map_location lets a checkpoint saved on a GPU be restored on a CPU host.
    checkpoint = torch.load(load_name, map_location=device)
    inception.load_state_dict(checkpoint['model'])
    inception = inception.to(device)
    print('load model successfully!')

    PATH = Path('data/cub')
    labels = pd.read_csv(PATH / "image_class_labels.txt", header=None, sep=" ")
    labels.columns = ["id", "label"]
    train_test = pd.read_csv(PATH / "train_test_split.txt", header=None, sep=" ")
    train_test.columns = ["id", "is_train"]
    images = pd.read_csv(PATH / "images.txt", header=None, sep=" ")
    images.columns = ["id", "name"]

    # train=False selects the held-out split of the CUB loader.
    valid_dataset = CUB(PATH, labels, train_test, images, train=False, transform=False)
    dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=100,
                                             shuffle=False, num_workers=0)

    correct, total = evaluate(inception, dataloader, device)
    if total == 0:
        # Avoid ZeroDivisionError when the split is empty / files are missing.
        print("accuracy undefined: no test images were loaded")
    else:
        print("accuracy", float(correct) / total)