Add model, train and eval scripts
syb7573330 committed Feb 5, 2018
1 parent 9dce447 commit a273841
Showing 6 changed files with 635 additions and 0 deletions.
173 changes: 173 additions & 0 deletions sem_seg/batch_inference.py
@@ -0,0 +1,173 @@
import argparse
import os
import sys
import numpy as np
import tensorflow as tf
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(BASE_DIR)
from model import *
import indoor3d_util

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during training [default: 1]')
parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
parser.add_argument('--model_path', required=True, help='model checkpoint file path')
parser.add_argument('--dump_dir', required=True, help='dump folder path')
parser.add_argument('--output_filelist', required=True, help='TXT filelist; each line is the output file for one room')
parser.add_argument('--room_data_filelist', required=True, help='TXT filelist; each line is a test room data/label file')
parser.add_argument('--no_clutter', action='store_true', help='If true, do not count the clutter class')
parser.add_argument('--visu', action='store_true', help='Whether to output OBJ file for prediction visualization.')
FLAGS = parser.parse_args()

BATCH_SIZE = FLAGS.batch_size
NUM_POINT = FLAGS.num_point
MODEL_PATH = FLAGS.model_path
GPU_INDEX = FLAGS.gpu
DUMP_DIR = FLAGS.dump_dir
if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR)
LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')
ROOM_PATH_LIST = [os.path.join(ROOT_DIR,line.rstrip()) for line in open(FLAGS.room_data_filelist)]

NUM_CLASSES = 13
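# The 13 S3DIS semantic classes; the last index (12, clutter) is what --no_clutter drops from the argmax.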

def log_string(out_str):
    LOG_FOUT.write(out_str+'\n')
    LOG_FOUT.flush()
    print(out_str)

def evaluate():
    is_training = False

    with tf.device('/gpu:'+str(GPU_INDEX)):
        pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
        is_training_pl = tf.placeholder(tf.bool, shape=())

        pred = get_model(pointclouds_pl, is_training_pl)
        loss = get_loss(pred, labels_pl)
        pred_softmax = tf.nn.softmax(pred)

        saver = tf.train.Saver()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    saver.restore(sess, MODEL_PATH)
    log_string("Model restored.")

    ops = {'pointclouds_pl': pointclouds_pl,
           'labels_pl': labels_pl,
           'is_training_pl': is_training_pl,
           'pred': pred,
           'pred_softmax': pred_softmax,
           'loss': loss}

    total_correct = 0
    total_seen = 0
    fout_out_filelist = open(FLAGS.output_filelist, 'w')
    for room_path in ROOM_PATH_LIST:
        out_data_label_filename = os.path.basename(room_path)[:-4] + '_pred.txt'
        out_data_label_filename = os.path.join(DUMP_DIR, out_data_label_filename)
        out_gt_label_filename = os.path.basename(room_path)[:-4] + '_gt.txt'
        out_gt_label_filename = os.path.join(DUMP_DIR, out_gt_label_filename)

        print(room_path, out_data_label_filename)
        # Evaluate rooms one by one.
        a, b = eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename)
        total_correct += a
        total_seen += b
        fout_out_filelist.write(out_data_label_filename+'\n')
    fout_out_filelist.close()
    log_string('all room eval accuracy: %f' % (total_correct / float(total_seen)))

def eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename):
    error_cnt = 0
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]

    if FLAGS.visu:
        fout = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_pred.obj'), 'w')
        fout_gt = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_gt.obj'), 'w')
        fout_real_color = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_real_color.obj'), 'w')
    fout_data_label = open(out_data_label_filename, 'w')
    fout_gt_label = open(out_gt_label_filename, 'w')

    current_data, current_label = indoor3d_util.room2blocks_wrapper_normalized(room_path, NUM_POINT)
    current_data = current_data[:, 0:NUM_POINT, :]
    current_label = np.squeeze(current_label)
    # Get room dimensions.
    data_label = np.load(room_path)
    data = data_label[:, 0:6]
    max_room_x = max(data[:, 0])
    max_room_y = max(data[:, 1])
    max_room_z = max(data[:, 2])

    file_size = current_data.shape[0]
    num_batches = file_size // BATCH_SIZE
    print(file_size)

    for batch_idx in range(num_batches):
        start_idx = batch_idx * BATCH_SIZE
        end_idx = (batch_idx+1) * BATCH_SIZE
        cur_batch_size = end_idx - start_idx

        feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
                     ops['labels_pl']: current_label[start_idx:end_idx],
                     ops['is_training_pl']: is_training}
        loss_val, pred_val = sess.run([ops['loss'], ops['pred_softmax']],
                                      feed_dict=feed_dict)

        if FLAGS.no_clutter:
            pred_label = np.argmax(pred_val[:, :, 0:12], 2)  # BxN
        else:
            pred_label = np.argmax(pred_val, 2)  # BxN

        # Dump predictions (and OBJ visualizations if requested)
        for b in range(BATCH_SIZE):
            pts = current_data[start_idx+b, :, :]
            l = current_label[start_idx+b, :]
            pts[:, 6] *= max_room_x
            pts[:, 7] *= max_room_y
            pts[:, 8] *= max_room_z
            pts[:, 3:6] *= 255.0
            pred = pred_label[b, :]
            for i in range(NUM_POINT):
                color = indoor3d_util.g_label2color[pred[i]]
                color_gt = indoor3d_util.g_label2color[current_label[start_idx+b, i]]
                if FLAGS.visu:
                    fout.write('v %f %f %f %d %d %d\n' % (pts[i, 6], pts[i, 7], pts[i, 8], color[0], color[1], color[2]))
                    fout_gt.write('v %f %f %f %d %d %d\n' % (pts[i, 6], pts[i, 7], pts[i, 8], color_gt[0], color_gt[1], color_gt[2]))
                fout_data_label.write('%f %f %f %d %d %d %f %d\n' % (pts[i, 6], pts[i, 7], pts[i, 8], pts[i, 3], pts[i, 4], pts[i, 5], pred_val[b, i, pred[i]], pred[i]))
                fout_gt_label.write('%d\n' % (l[i]))

        correct = np.sum(pred_label == current_label[start_idx:end_idx, :])
        total_correct += correct
        total_seen += (cur_batch_size * NUM_POINT)
        loss_sum += (loss_val * BATCH_SIZE)
        for i in range(start_idx, end_idx):
            for j in range(NUM_POINT):
                l = current_label[i, j]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_label[i-start_idx, j] == l)

    log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT)))
    log_string('eval accuracy: %f' % (total_correct / float(total_seen)))
    fout_data_label.close()
    fout_gt_label.close()
    if FLAGS.visu:
        fout.close()
        fout_gt.close()
        fout_real_color.close()
    return total_correct, total_seen


if __name__ == '__main__':
    with tf.Graph().as_default():
        evaluate()
    LOG_FOUT.close()
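For reference, a typical invocation (mirroring test_job.sh below) is:

    python batch_inference.py --model_path log1/epoch_60.ckpt --dump_dir log1/dump \
        --output_filelist log1/output_filelist.txt \
        --room_data_filelist meta/area1_data_label.txt --visu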
44 changes: 44 additions & 0 deletions sem_seg/eval_iou_accuracy.py
@@ -0,0 +1,44 @@
import numpy as np

pred_data_label_filenames = []
for i in range(1, 7):
    file_name = 'log{}/output_filelist.txt'.format(i)
    pred_data_label_filenames += [line.rstrip() for line in open(file_name)]

gt_label_filenames = [f[:-len('_pred.txt')] + '_gt.txt' for f in pred_data_label_filenames]

num_room = len(gt_label_filenames)

gt_classes = [0 for _ in range(13)]
positive_classes = [0 for _ in range(13)]
true_positive_classes = [0 for _ in range(13)]

for i in range(num_room):
    print(i)
    data_label = np.loadtxt(pred_data_label_filenames[i])
    pred_label = data_label[:, -1]
    gt_label = np.loadtxt(gt_label_filenames[i])
    print(gt_label.shape)
    for j in range(gt_label.shape[0]):
        gt_l = int(gt_label[j])
        pred_l = int(pred_label[j])
        gt_classes[gt_l] += 1
        positive_classes[pred_l] += 1
        true_positive_classes[gt_l] += int(gt_l == pred_l)


print(gt_classes)
print(positive_classes)
print(true_positive_classes)

print('Overall accuracy: {0}'.format(sum(true_positive_classes)/float(sum(positive_classes))))

print('IoU:')
iou_list = []
for i in range(13):
    iou = true_positive_classes[i] / float(gt_classes[i] + positive_classes[i] - true_positive_classes[i])
    print(iou)
    iou_list.append(iou)

print('avg IoU:')
print(sum(iou_list) / 13.0)
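Per-class IoU here is TP / (GT + predicted-positive - TP), i.e. the intersection of the ground-truth and predicted point sets for a class divided by their union. A quick sanity check with illustrative numbers (not actual results):

    # Illustrative numbers only: a class with 100 ground-truth points,
    # 90 predicted points, and 80 true positives.
    tp, gt, pos = 80, 100, 90
    iou = tp / float(gt + pos - tp)  # 80 / 110 ~= 0.727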
120 changes: 120 additions & 0 deletions sem_seg/model.py
@@ -0,0 +1,120 @@
import tensorflow as tf
import math
import time
import numpy as np
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(BASE_DIR, '../models'))
import tf_util

def placeholder_inputs(batch_size, num_point):
    pointclouds_pl = tf.placeholder(tf.float32,
                                    shape=(batch_size, num_point, 9))
    labels_pl = tf.placeholder(tf.int32,
                               shape=(batch_size, num_point))
    return pointclouds_pl, labels_pl

def get_model(point_cloud, is_training, bn_decay=None):
    """ ConvNet baseline, input is a BxNx9 point cloud """
    batch_size = point_cloud.get_shape()[0].value
    num_point = point_cloud.get_shape()[1].value
    input_image = tf.expand_dims(point_cloud, -1)

    k = 30

    # Build the kNN graph on the normalized XYZ coordinates
    # (the last 3 of the 9 input channels).
    adj = tf_util.pairwise_distance(point_cloud[:, :, 6:])
    nn_idx = tf_util.knn(adj, k=k)  # (batch, num_points, k)
    edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)

    out1 = tf_util.conv2d(edge_feature, 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv1', bn_decay=bn_decay, is_dist=True)

    out2 = tf_util.conv2d(out1, 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv2', bn_decay=bn_decay, is_dist=True)

    net_max_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
    net_mean_1 = tf.reduce_mean(out2, axis=-2, keep_dims=True)

    out3 = tf_util.conv2d(tf.concat([net_max_1, net_mean_1], axis=-1), 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv3', bn_decay=bn_decay, is_dist=True)

    # Recompute the neighborhood graph in feature space for the next block.
    adj = tf_util.pairwise_distance(tf.squeeze(out3, axis=-2))
    nn_idx = tf_util.knn(adj, k=k)
    edge_feature = tf_util.get_edge_feature(out3, nn_idx=nn_idx, k=k)

    out4 = tf_util.conv2d(edge_feature, 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv4', bn_decay=bn_decay, is_dist=True)

    net_max_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
    net_mean_2 = tf.reduce_mean(out4, axis=-2, keep_dims=True)

    out5 = tf_util.conv2d(tf.concat([net_max_2, net_mean_2], axis=-1), 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv5', bn_decay=bn_decay, is_dist=True)

    adj = tf_util.pairwise_distance(tf.squeeze(out5, axis=-2))
    nn_idx = tf_util.knn(adj, k=k)
    edge_feature = tf_util.get_edge_feature(out5, nn_idx=nn_idx, k=k)

    out6 = tf_util.conv2d(edge_feature, 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv6', bn_decay=bn_decay, is_dist=True)

    net_max_3 = tf.reduce_max(out6, axis=-2, keep_dims=True)
    net_mean_3 = tf.reduce_mean(out6, axis=-2, keep_dims=True)

    out7 = tf_util.conv2d(tf.concat([net_max_3, net_mean_3], axis=-1), 64, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv7', bn_decay=bn_decay, is_dist=True)

    out8 = tf_util.conv2d(tf.concat([out3, out5, out7], axis=-1), 1024, [1, 1],
                          padding='VALID', stride=[1, 1],
                          bn=True, is_training=is_training,
                          scope='adj_conv8', bn_decay=bn_decay, is_dist=True)

    # Global feature: max-pool over all points, then broadcast back to each point.
    out_max = tf_util.max_pool2d(out8, [num_point, 1], padding='VALID', scope='maxpool')
    expand = tf.tile(out_max, [1, num_point, 1, 1])

    # Concatenate the global feature with the per-point intermediate features.
    concat = tf.concat(axis=3, values=[expand,
                                       net_max_1,
                                       net_mean_1,
                                       out3,
                                       net_max_2,
                                       net_mean_2,
                                       out5,
                                       net_max_3,
                                       net_mean_3,
                                       out7,
                                       out8])

    # Per-point segmentation head
    net = tf_util.conv2d(concat, 512, [1, 1], padding='VALID', stride=[1, 1],
                         bn=True, is_training=is_training, scope='seg/conv1', is_dist=True)
    net = tf_util.conv2d(net, 256, [1, 1], padding='VALID', stride=[1, 1],
                         bn=True, is_training=is_training, scope='seg/conv2', is_dist=True)
    net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='dp1')
    net = tf_util.conv2d(net, 13, [1, 1], padding='VALID', stride=[1, 1],
                         activation_fn=None, scope='seg/conv3', is_dist=True)
    net = tf.squeeze(net, [2])

    return net

def get_loss(pred, label):
    """ pred: B,N,13; label: B,N """
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label)
    return tf.reduce_mean(loss)
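The graph helpers used by get_model (tf_util.pairwise_distance, tf_util.knn, tf_util.get_edge_feature) are not part of this commit. A minimal sketch of what they compute, assuming TF 1.x tensors with the static shapes used above:

def pairwise_distance(point_cloud):
    # Squared pairwise distances ||xi - xj||^2 = ||xi||^2 - 2*xi.xj + ||xj||^2,
    # for point_cloud of shape (batch, num_points, num_dims) -> (batch, N, N).
    inner = -2 * tf.matmul(point_cloud, point_cloud, transpose_b=True)
    square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True)
    return square + inner + tf.transpose(square, perm=[0, 2, 1])

def knn(adj_matrix, k=20):
    # Indices of the k smallest pairwise distances: (batch, num_points, k).
    _, nn_idx = tf.nn.top_k(-adj_matrix, k=k)
    return nn_idx

def get_edge_feature(point_cloud, nn_idx, k=20):
    # Pair each point with its k neighbors and concatenate (x, x_j - x),
    # giving (batch, num_points, k, 2*num_dims). Assumes batch size > 1.
    point_cloud = tf.squeeze(point_cloud)  # drop size-1 axes from conv2d outputs
    batch_size = point_cloud.get_shape()[0].value
    num_points = point_cloud.get_shape()[1].value
    num_dims = point_cloud.get_shape()[2].value

    # Offset neighbor indices so they address the flattened (B*N, D) tensor.
    idx_base = tf.reshape(tf.range(batch_size) * num_points, [batch_size, 1, 1])
    flat = tf.reshape(point_cloud, [-1, num_dims])
    neighbors = tf.gather(flat, nn_idx + idx_base)
    central = tf.tile(tf.expand_dims(point_cloud, axis=-2), [1, 1, k, 1])
    return tf.concat([central, neighbors - central], axis=-1)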
6 changes: 6 additions & 0 deletions sem_seg/test_job.sh
@@ -0,0 +1,6 @@
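# Run inference once per S3DIS test area; eval_iou_accuracy.py then aggregates log1..log6.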
python batch_inference.py --model_path log1/epoch_60.ckpt --dump_dir log1/dump --output_filelist log1/output_filelist.txt --room_data_filelist meta/area1_data_label.txt
python batch_inference.py --model_path log2/epoch_60.ckpt --dump_dir log2/dump --output_filelist log2/output_filelist.txt --room_data_filelist meta/area2_data_label.txt
python batch_inference.py --model_path log3/epoch_60.ckpt --dump_dir log3/dump --output_filelist log3/output_filelist.txt --room_data_filelist meta/area3_data_label.txt
python batch_inference.py --model_path log4/epoch_60.ckpt --dump_dir log4/dump --output_filelist log4/output_filelist.txt --room_data_filelist meta/area4_data_label.txt
python batch_inference.py --model_path log5/epoch_60.ckpt --dump_dir log5/dump --output_filelist log5/output_filelist.txt --room_data_filelist meta/area5_data_label.txt
python batch_inference.py --model_path log6/epoch_60.ckpt --dump_dir log6/dump --output_filelist log6/output_filelist.txt --room_data_filelist meta/area6_data_label.txt