# evalpyt2.py (forked from isht7/pytorch-deeplab-resnet)
import scipy
from scipy import ndimage
import cv2
import numpy as np
import sys
#sys.path.insert(0,'/data1/ravikiran/SketchObjPartSegmentation/src/caffe-switch/caffe/python')
#import caffe
import torch
from torch.autograd import Variable
import torchvision.models as models
import torch.nn.functional as F
import deeplab_resnet
from collections import OrderedDict
import os
from os import walk
import matplotlib.pyplot as plt
import torch.nn as nn
from docopt import docopt
docstr = """Evaluate ResNet-DeepLab trained on scenes (VOC 2012),a total of 21 labels including background
Usage:
evalpyt.py [options]
Options:
-h, --help Print this message
--visualize view outputs of each sketch
--snapPrefix=<str> Snapshot [default: VOC12_scenes_]
--testGTpath=<str> Ground truth path prefix [default: data/gt/]
--testIMpath=<str> Sketch images path prefix [default: data/img/]
--NoLabels=<int> The number of different labels in training data, VOC has 21 labels, including background [default: 21]
--gpu0=<int> GPU number [default: 0]
"""
args = docopt(docstr, version='v0.1')
print args

max_label = int(args['--NoLabels']) - 1  # labels run from 0, 1, ..., 20 (for VOC)
def fast_hist(a, b, n):
    # Accumulate an n x n confusion matrix: entry [i, j] counts pixels whose
    # ground-truth label is i and predicted label is j. Pixels whose ground-truth
    # label falls outside [0, n) (e.g. VOC's 255 "ignore" label) are skipped.
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n**2).reshape(n, n)
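
# For example (a sketch), with n = 3 classes:
#   fast_hist(np.array([0, 1, 2, 2]), np.array([0, 1, 1, 2]), 3)
# gives [[1, 0, 0],
#        [0, 1, 0],
#        [0, 1, 1]]   (one class-2 pixel was mispredicted as class 1).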
def get_iou(pred, gt):
    # Per-image IoU, averaged over the number of distinct labels present in the ground truth.
    if pred.shape != gt.shape:
        print 'pred shape', pred.shape, 'gt shape', gt.shape
    assert(pred.shape == gt.shape)
    gt = gt.astype(np.float32)
    pred = pred.astype(np.float32)

    count = np.zeros((max_label+1,))
    for j in range(max_label+1):
        x = np.where(pred == j)
        p_idx_j = set(zip(x[0].tolist(), x[1].tolist()))
        x = np.where(gt == j)
        GT_idx_j = set(zip(x[0].tolist(), x[1].tolist()))
        #pdb.set_trace()
        n_jj = set.intersection(p_idx_j, GT_idx_j)
        u_jj = set.union(p_idx_j, GT_idx_j)

        if len(GT_idx_j) != 0:
            count[j] = float(len(n_jj)) / float(len(u_jj))

    result_class = count
    Aiou = np.sum(result_class[:]) / float(len(np.unique(gt)))

    return Aiou
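
# Note: the per-image IoUs collected in pytorch_list below come from get_iou, but the
# number printed at the end is computed from the confusion matrix accumulated with
# fast_hist, which is the standard VOC-style mean-IoU computation.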
gpu0 = int(args['--gpu0'])
im_path = args['--testIMpath']
model = deeplab_resnet.Res_Deeplab(int(args['--NoLabels']))
model.eval()
counter = 0
model.cuda(gpu0)
snapPrefix = args['--snapPrefix']
gt_path = args['--testGTpath']
img_list = open('data/list/val.txt').readlines()
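# val.txt is assumed to contain one image name per line, without an extension; the
# trailing newline is stripped with i[:-1] before .jpg / .png is appended below.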
for iter in range(1, 21):  # TODO set the (different iteration) models that you want to evaluate on. Models are saved during training after each 1000 iters by default.
    saved_state_dict = torch.load(os.path.join('data/snapshots/', snapPrefix + str(iter) + '000.pth'))
    if counter == 0:
        print snapPrefix
        counter += 1
    model.load_state_dict(saved_state_dict)

    hist = np.zeros((max_label+1, max_label+1))
    pytorch_list = []
    for i in img_list:
        img = np.zeros((513, 513, 3))
        img_temp = cv2.imread(os.path.join(im_path, i[:-1] + '.jpg')).astype(float)
        img_original = img_temp.copy()  # keep an unmodified copy for visualization
        # Subtract the per-channel mean pixel values (BGR order, as loaded by cv2).
        img_temp[:, :, 0] = img_temp[:, :, 0] - 104.008
        img_temp[:, :, 1] = img_temp[:, :, 1] - 116.669
        img_temp[:, :, 2] = img_temp[:, :, 2] - 122.675
        # Zero-pad the image into a fixed 513x513 canvas.
        img[:img_temp.shape[0], :img_temp.shape[1], :] = img_temp
        gt = cv2.imread(os.path.join(gt_path, i[:-1] + '.png'), 0)
        #gt[gt==255] = 0
        # Forward pass: add a batch dim, go NHWC -> NCHW; volatile=True disables autograd (pre-0.4 PyTorch API).
        output = model(Variable(torch.from_numpy(img[np.newaxis, :].transpose(0, 3, 1, 2)).float(), volatile=True).cuda(gpu0))
        # The model returns several outputs; take the last one (index 3) and upsample it to 513x513.
        interp = nn.UpsamplingBilinear2d(size=(513, 513))
        output = interp(output[3]).cpu().data[0].numpy()
        # Crop back to the original image size and take the per-pixel argmax over classes.
        output = output[:, :img_temp.shape[0], :img_temp.shape[1]]
        output = output.transpose(1, 2, 0)
        output = np.argmax(output, axis=2)
        if args['--visualize']:
            plt.subplot(3, 1, 1)
            plt.imshow(img_original)
            plt.subplot(3, 1, 2)
            plt.imshow(gt)
            plt.subplot(3, 1, 3)
            plt.imshow(output)
            plt.show()
        iou_pytorch = get_iou(output, gt)
        pytorch_list.append(iou_pytorch)
        hist += fast_hist(gt.flatten(), output.flatten(), max_label+1)
    # miou is the vector of per-class IoUs from the accumulated confusion matrix;
    # the printed value is their mean over all classes.
    miou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    print 'pytorch', iter, "Mean iou = ", np.sum(miou) / len(miou)
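    # Optional (a sketch, not part of the original script): also report per-class IoU,
    # assuming the VOC convention that class 0 is the background:
    #   for cls in range(len(miou)):
    #       print 'class', cls, 'IoU =', miou[cls]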