Skip to content

Commit

Permalink
add evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
WangYueFt committed Mar 24, 2018
1 parent 18294d7 commit d767b0e
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import tensorflow as tf
import numpy as np
import argparse
import socket
import importlib
import time
import os
import scipy.misc
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'models'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
import provider
import pc_util


parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('--model', default='dgcnn', help='Model name: dgcnn [default: dgcnn]')
parser.add_argument('--batch_size', type=int, default=4, help='Batch Size during training [default: 1]')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
parser.add_argument('--model_path', default='log/model.ckpt', help='model checkpoint file path [default: log/model.ckpt]')
parser.add_argument('--dump_dir', default='dump', help='dump folder path [dump]')
parser.add_argument('--visu', action='store_true', help='Whether to dump image for error case [default: False]')
FLAGS = parser.parse_args()


BATCH_SIZE = FLAGS.batch_size
NUM_POINT = FLAGS.num_point
MODEL_PATH = FLAGS.model_path
GPU_INDEX = FLAGS.gpu
MODEL = importlib.import_module(FLAGS.model) # import network module
DUMP_DIR = FLAGS.dump_dir
if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR)
LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')

NUM_CLASSES = 40
SHAPE_NAMES = [line.rstrip() for line in \
open(os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/shape_names.txt'))]

HOSTNAME = socket.gethostname()

# ModelNet40 official train/test split
TRAIN_FILES = provider.getDataFiles( \
os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
TEST_FILES = provider.getDataFiles(\
os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))

def log_string(out_str):
LOG_FOUT.write(out_str+'\n')
LOG_FOUT.flush()
print(out_str)

def evaluate(num_votes):
is_training = False

with tf.device('/gpu:'+str(GPU_INDEX)):
pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
is_training_pl = tf.placeholder(tf.bool, shape=())

# simple model
pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl)
loss = MODEL.get_loss(pred, labels_pl, end_points)

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Create a session
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.allow_soft_placement = True
config.log_device_placement = True
sess = tf.Session(config=config)

# Restore variables from disk.
saver.restore(sess, MODEL_PATH)
log_string("Model restored.")

ops = {'pointclouds_pl': pointclouds_pl,
'labels_pl': labels_pl,
'is_training_pl': is_training_pl,
'pred': pred,
'loss': loss}

eval_one_epoch(sess, ops, num_votes)


def eval_one_epoch(sess, ops, num_votes=1, topk=1):
error_cnt = 0
is_training = False
total_correct = 0
total_seen = 0
loss_sum = 0
total_seen_class = [0 for _ in range(NUM_CLASSES)]
total_correct_class = [0 for _ in range(NUM_CLASSES)]
fout = open(os.path.join(DUMP_DIR, 'pred_label.txt'), 'w')
for fn in range(len(TEST_FILES)):
log_string('----'+str(fn)+'----')
current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
current_data = current_data[:,0:NUM_POINT,:]
current_label = np.squeeze(current_label)
print(current_data.shape)

file_size = current_data.shape[0]
num_batches = file_size // BATCH_SIZE
print(file_size)

for batch_idx in range(num_batches):
start_idx = batch_idx * BATCH_SIZE
end_idx = (batch_idx+1) * BATCH_SIZE
cur_batch_size = end_idx - start_idx

# Aggregating BEG
batch_loss_sum = 0 # sum of losses for the batch
batch_pred_sum = np.zeros((cur_batch_size, NUM_CLASSES)) # score for classes
batch_pred_classes = np.zeros((cur_batch_size, NUM_CLASSES)) # 0/1 for classes
for vote_idx in range(num_votes):
rotated_data = provider.rotate_point_cloud_by_angle(current_data[start_idx:end_idx, :, :],
vote_idx/float(num_votes) * np.pi * 2)
feed_dict = {ops['pointclouds_pl']: rotated_data,
ops['labels_pl']: current_label[start_idx:end_idx],
ops['is_training_pl']: is_training}
loss_val, pred_val = sess.run([ops['loss'], ops['pred']],
feed_dict=feed_dict)
batch_pred_sum += pred_val
batch_pred_val = np.argmax(pred_val, 1)
for el_idx in range(cur_batch_size):
batch_pred_classes[el_idx, batch_pred_val[el_idx]] += 1
batch_loss_sum += (loss_val * cur_batch_size / float(num_votes))
# pred_val_topk = np.argsort(batch_pred_sum, axis=-1)[:,-1*np.array(range(topk))-1]
# pred_val = np.argmax(batch_pred_classes, 1)
pred_val = np.argmax(batch_pred_sum, 1)
# Aggregating END

correct = np.sum(pred_val == current_label[start_idx:end_idx])
# correct = np.sum(pred_val_topk[:,0:topk] == label_val)
total_correct += correct
total_seen += cur_batch_size
loss_sum += batch_loss_sum

for i in range(start_idx, end_idx):
l = current_label[i]
total_seen_class[l] += 1
total_correct_class[l] += (pred_val[i-start_idx] == l)
fout.write('%d, %d\n' % (pred_val[i-start_idx], l))

if pred_val[i-start_idx] != l and FLAGS.visu: # ERROR CASE, DUMP!
img_filename = '%d_label_%s_pred_%s.jpg' % (error_cnt, SHAPE_NAMES[l],
SHAPE_NAMES[pred_val[i-start_idx]])
img_filename = os.path.join(DUMP_DIR, img_filename)
output_img = pc_util.point_cloud_three_views(np.squeeze(current_data[i, :, :]))
scipy.misc.imsave(img_filename, output_img)
error_cnt += 1

log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float))))

class_accuracies = np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)
for i, name in enumerate(SHAPE_NAMES):
log_string('%10s:\t%0.3f' % (name, class_accuracies[i]))



if __name__=='__main__':
with tf.Graph().as_default():
evaluate(num_votes=12)
LOG_FOUT.close()

0 comments on commit d767b0e

Please sign in to comment.