From a2738411cbdf1558a5fab2f10f462ffe97a64c18 Mon Sep 17 00:00:00 2001 From: Yongbin Sun Date: Sun, 4 Feb 2018 20:13:21 -0500 Subject: [PATCH] Add model, train and eval scripts --- sem_seg/batch_inference.py | 173 +++++++++++++++++++++ sem_seg/eval_iou_accuracy.py | 44 ++++++ sem_seg/model.py | 120 +++++++++++++++ sem_seg/test_job.sh | 6 + sem_seg/train.py | 286 +++++++++++++++++++++++++++++++++++ sem_seg/train_job.sh | 6 + 6 files changed, 635 insertions(+) create mode 100644 sem_seg/batch_inference.py create mode 100644 sem_seg/eval_iou_accuracy.py create mode 100644 sem_seg/model.py create mode 100644 sem_seg/test_job.sh create mode 100644 sem_seg/train.py create mode 100644 sem_seg/train_job.sh diff --git a/sem_seg/batch_inference.py b/sem_seg/batch_inference.py new file mode 100644 index 0000000..a300277 --- /dev/null +++ b/sem_seg/batch_inference.py @@ -0,0 +1,173 @@ +import argparse +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.dirname(BASE_DIR) +sys.path.append(BASE_DIR) +from model import * +import indoor3d_util + +parser = argparse.ArgumentParser() +parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]') +parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during training [default: 1]') +parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]') +parser.add_argument('--model_path', required=True, help='model checkpoint file path') +parser.add_argument('--dump_dir', required=True, help='dump folder path') +parser.add_argument('--output_filelist', required=True, help='TXT filename, filelist, each line is an output for a room') +parser.add_argument('--room_data_filelist', required=True, help='TXT filename, filelist, each line is a test room data label file.') +parser.add_argument('--no_clutter', action='store_true', help='If true, donot count the clutter class') +parser.add_argument('--visu', action='store_true', 
help='Whether to output OBJ file for prediction visualization.') +FLAGS = parser.parse_args() + +BATCH_SIZE = FLAGS.batch_size +NUM_POINT = FLAGS.num_point +MODEL_PATH = FLAGS.model_path +GPU_INDEX = FLAGS.gpu +DUMP_DIR = FLAGS.dump_dir +if not os.path.exists(DUMP_DIR): os.mkdir(DUMP_DIR) +LOG_FOUT = open(os.path.join(DUMP_DIR, 'log_evaluate.txt'), 'w') +LOG_FOUT.write(str(FLAGS)+'\n') +ROOM_PATH_LIST = [os.path.join(ROOT_DIR,line.rstrip()) for line in open(FLAGS.room_data_filelist)] + +NUM_CLASSES = 13 + +def log_string(out_str): + LOG_FOUT.write(out_str+'\n') + LOG_FOUT.flush() + print(out_str) + +def evaluate(): + is_training = False + + with tf.device('/gpu:'+str(GPU_INDEX)): + pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT) + is_training_pl = tf.placeholder(tf.bool, shape=()) + + pred = get_model(pointclouds_pl, is_training_pl) + loss = get_loss(pred, labels_pl) + pred_softmax = tf.nn.softmax(pred) + + saver = tf.train.Saver() + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + sess = tf.Session(config=config) + + saver.restore(sess, MODEL_PATH) + log_string("Model restored.") + + ops = {'pointclouds_pl': pointclouds_pl, + 'labels_pl': labels_pl, + 'is_training_pl': is_training_pl, + 'pred': pred, + 'pred_softmax': pred_softmax, + 'loss': loss} + + total_correct = 0 + total_seen = 0 + fout_out_filelist = open(FLAGS.output_filelist, 'w') + for room_path in ROOM_PATH_LIST: + out_data_label_filename = os.path.basename(room_path)[:-4] + '_pred.txt' + out_data_label_filename = os.path.join(DUMP_DIR, out_data_label_filename) + out_gt_label_filename = os.path.basename(room_path)[:-4] + '_gt.txt' + out_gt_label_filename = os.path.join(DUMP_DIR, out_gt_label_filename) + + print(room_path, out_data_label_filename) + # Evaluate room one by one. 
+ a, b = eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename) + total_correct += a + total_seen += b + fout_out_filelist.write(out_data_label_filename+'\n') + fout_out_filelist.close() + log_string('all room eval accuracy: %f'% (total_correct / float(total_seen))) + +def eval_one_epoch(sess, ops, room_path, out_data_label_filename, out_gt_label_filename): + error_cnt = 0 + is_training = False + total_correct = 0 + total_seen = 0 + loss_sum = 0 + total_seen_class = [0 for _ in range(NUM_CLASSES)] + total_correct_class = [0 for _ in range(NUM_CLASSES)] + + if FLAGS.visu: + fout = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_pred.obj'), 'w') + fout_gt = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_gt.obj'), 'w') + fout_real_color = open(os.path.join(DUMP_DIR, os.path.basename(room_path)[:-4]+'_real_color.obj'), 'w') + fout_data_label = open(out_data_label_filename, 'w') + fout_gt_label = open(out_gt_label_filename, 'w') + + current_data, current_label = indoor3d_util.room2blocks_wrapper_normalized(room_path, NUM_POINT) + current_data = current_data[:,0:NUM_POINT,:] + current_label = np.squeeze(current_label) + # Get room dimension.. 
+ data_label = np.load(room_path) + data = data_label[:,0:6] + max_room_x = max(data[:,0]) + max_room_y = max(data[:,1]) + max_room_z = max(data[:,2]) + + file_size = current_data.shape[0] + num_batches = file_size // BATCH_SIZE + print(file_size) + + + for batch_idx in range(num_batches): + start_idx = batch_idx * BATCH_SIZE + end_idx = (batch_idx+1) * BATCH_SIZE + cur_batch_size = end_idx - start_idx + + feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], + ops['labels_pl']: current_label[start_idx:end_idx], + ops['is_training_pl']: is_training} + loss_val, pred_val = sess.run([ops['loss'], ops['pred_softmax']], + feed_dict=feed_dict) + + if FLAGS.no_clutter: + pred_label = np.argmax(pred_val[:,:,0:12], 2) # BxN + else: + pred_label = np.argmax(pred_val, 2) # BxN + + # Save prediction labels to OBJ file + for b in range(BATCH_SIZE): + pts = current_data[start_idx+b, :, :] + l = current_label[start_idx+b,:] + pts[:,6] *= max_room_x + pts[:,7] *= max_room_y + pts[:,8] *= max_room_z + pts[:,3:6] *= 255.0 + pred = pred_label[b, :] + for i in range(NUM_POINT): + color = indoor3d_util.g_label2color[pred[i]] + color_gt = indoor3d_util.g_label2color[current_label[start_idx+b, i]] + if FLAGS.visu: + fout.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color[0], color[1], color[2])) + fout_gt.write('v %f %f %f %d %d %d\n' % (pts[i,6], pts[i,7], pts[i,8], color_gt[0], color_gt[1], color_gt[2])) + fout_data_label.write('%f %f %f %d %d %d %f %d\n' % (pts[i,6], pts[i,7], pts[i,8], pts[i,3], pts[i,4], pts[i,5], pred_val[b,i,pred[i]], pred[i])) + fout_gt_label.write('%d\n' % (l[i])) + + correct = np.sum(pred_label == current_label[start_idx:end_idx,:]) + total_correct += correct + total_seen += (cur_batch_size*NUM_POINT) + loss_sum += (loss_val*BATCH_SIZE) + for i in range(start_idx, end_idx): + for j in range(NUM_POINT): + l = current_label[i, j] + total_seen_class[l] += 1 + total_correct_class[l] += (pred_label[i-start_idx, j] == l) + + 
log_string('eval mean loss: %f' % (loss_sum / float(total_seen/NUM_POINT)))
+    log_string('eval accuracy: %f'% (total_correct / float(total_seen)))
+    fout_data_label.close()
+    fout_gt_label.close()
+    if FLAGS.visu:
+        fout.close()
+        fout_gt.close()
+    return total_correct, total_seen
+
+
+if __name__=='__main__':
+    with tf.Graph().as_default():
+        evaluate()
+    LOG_FOUT.close()
diff --git a/sem_seg/eval_iou_accuracy.py b/sem_seg/eval_iou_accuracy.py
new file mode 100644
index 0000000..fbcc220
--- /dev/null
+++ b/sem_seg/eval_iou_accuracy.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+pred_data_label_filenames = []
+for i in range(1,7):
+    file_name = 'log{}/output_filelist.txt'.format(i)
+    pred_data_label_filenames += [line.rstrip() for line in open(file_name)]
+
+# Derive '..._gt.txt' names from '..._pred.txt' names by slicing off the
+# fixed-length suffix. NOTE: str.rstrip() strips a *character set*, not a
+# suffix, so rstrip('_pred\.txt') would also eat trailing d/e/p/r/t/x/./_
+gt_label_filenames = [f[:-len('_pred.txt')] + '_gt.txt' for f in pred_data_label_filenames]
+
+num_room = len(gt_label_filenames)
+
+gt_classes = [0 for _ in range(13)]
+positive_classes = [0 for _ in range(13)]
+true_positive_classes = [0 for _ in range(13)]
+
+for i in range(num_room):
+    print(i)
+    data_label = np.loadtxt(pred_data_label_filenames[i])
+    pred_label = data_label[:,-1]
+    gt_label = np.loadtxt(gt_label_filenames[i])
+    print(gt_label.shape)
+    for j in range(gt_label.shape[0]):
+        gt_l = int(gt_label[j])
+        pred_l = int(pred_label[j])
+        gt_classes[gt_l] += 1
+        positive_classes[pred_l] += 1
+        true_positive_classes[gt_l] += int(gt_l==pred_l)
+
+
+print(gt_classes)
+print(positive_classes)
+print(true_positive_classes)
+
+print('Overall accuracy: {0}'.format(sum(true_positive_classes)/float(sum(positive_classes))))
+
+print('IoU:')
+iou_list = []
+for i in range(13):
+    iou = true_positive_classes[i]/float(gt_classes[i]+positive_classes[i]-true_positive_classes[i])
+    print(iou)
+    iou_list.append(iou)
+
+print('avg IoU:')
+print(sum(iou_list)/13.0)
diff --git a/sem_seg/model.py b/sem_seg/model.py
new file mode 100644
index 0000000..9adade3
--- /dev/null
+++ b/sem_seg/model.py
@@ -0,0 +1,120 @@
+import
tensorflow as tf +import math +import time +import numpy as np +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.dirname(BASE_DIR) +sys.path.append(os.path.join(ROOT_DIR, 'utils')) +sys.path.append(os.path.join(BASE_DIR, '../models')) +import tf_util + +def placeholder_inputs(batch_size, num_point): + pointclouds_pl = tf.placeholder(tf.float32, + shape=(batch_size, num_point, 9)) + labels_pl = tf.placeholder(tf.int32, + shape=(batch_size, num_point)) + return pointclouds_pl, labels_pl + +def get_model(point_cloud, is_training, bn_decay=None): + """ ConvNet baseline, input is BxNx9 gray image """ + batch_size = point_cloud.get_shape()[0].value + num_point = point_cloud.get_shape()[1].value + input_image = tf.expand_dims(point_cloud, -1) + + k = 30 + + adj = tf_util.pairwise_distance(point_cloud[:, :, 6:]) + nn_idx = tf_util.knn(adj, k=k) # (batch, num_points, k) + edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k) + + out1 = tf_util.conv2d(edge_feature, 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv1', bn_decay=bn_decay, is_dist=True) + + out2 = tf_util.conv2d(out1, 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv2', bn_decay=bn_decay, is_dist=True) + + net_max_1 = tf.reduce_max(out2, axis=-2, keep_dims=True) + net_mean_1 = tf.reduce_mean(out2, axis=-2, keep_dims=True) + + out3 = tf_util.conv2d(tf.concat([net_max_1, net_mean_1], axis=-1), 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv3', bn_decay=bn_decay, is_dist=True) + + adj = tf_util.pairwise_distance(tf.squeeze(out3, axis=-2)) + nn_idx = tf_util.knn(adj, k=k) + edge_feature = tf_util.get_edge_feature(out3, nn_idx=nn_idx, k=k) + + out4 = tf_util.conv2d(edge_feature, 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv4', bn_decay=bn_decay, 
is_dist=True) + + net_max_2 = tf.reduce_max(out4, axis=-2, keep_dims=True) + net_mean_2 = tf.reduce_mean(out4, axis=-2, keep_dims=True) + + out5 = tf_util.conv2d(tf.concat([net_max_2, net_mean_2], axis=-1), 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv5', bn_decay=bn_decay, is_dist=True) + + adj = tf_util.pairwise_distance(tf.squeeze(out5, axis=-2)) + nn_idx = tf_util.knn(adj, k=k) + edge_feature = tf_util.get_edge_feature(out5, nn_idx=nn_idx, k=k) + + out6 = tf_util.conv2d(edge_feature, 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv6', bn_decay=bn_decay, is_dist=True) + + net_max_3 = tf.reduce_max(out6, axis=-2, keep_dims=True) + net_mean_3 = tf.reduce_mean(out6, axis=-2, keep_dims=True) + + out7 = tf_util.conv2d(tf.concat([net_max_3, net_mean_3], axis=-1), 64, [1,1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv7', bn_decay=bn_decay, is_dist=True) + + out8 = tf_util.conv2d(tf.concat([out3, out5, out7], axis=-1), 1024, [1, 1], + padding='VALID', stride=[1,1], + bn=True, is_training=is_training, + scope='adj_conv8', bn_decay=bn_decay, is_dist=True) + + out_max = tf_util.max_pool2d(out8, [num_point,1], padding='VALID', scope='maxpool') + + expand = tf.tile(out_max, [1, num_point, 1, 1]) + + concat = tf.concat(axis=3, values=[expand, + net_max_1, + net_mean_1, + out3, + net_max_2, + net_mean_2, + out5, + net_max_3, + net_mean_3, + out7, + out8]) + + # CONV + net = tf_util.conv2d(concat, 512, [1,1], padding='VALID', stride=[1,1], + bn=True, is_training=is_training, scope='seg/conv1', is_dist=True) + net = tf_util.conv2d(net, 256, [1,1], padding='VALID', stride=[1,1], + bn=True, is_training=is_training, scope='seg/conv2', is_dist=True) + net = tf_util.dropout(net, keep_prob=0.7, is_training=is_training, scope='dp1') + net = tf_util.conv2d(net, 13, [1,1], padding='VALID', stride=[1,1], + activation_fn=None, scope='seg/conv3', 
is_dist=True) + net = tf.squeeze(net, [2]) + + return net + +def get_loss(pred, label): + """ pred: B,N,13; label: B,N """ + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label) + return tf.reduce_mean(loss) diff --git a/sem_seg/test_job.sh b/sem_seg/test_job.sh new file mode 100644 index 0000000..0ff781e --- /dev/null +++ b/sem_seg/test_job.sh @@ -0,0 +1,6 @@ +python batch_inference.py --model_path log1/epoch_60.ckpt --dump_dir log1/dump --output_filelist log1/output_filelist.txt --room_data_filelist meta/area1_data_label.txt +python batch_inference.py --model_path log2/epoch_60.ckpt --dump_dir log2/dump --output_filelist log2/output_filelist.txt --room_data_filelist meta/area2_data_label.txt +python batch_inference.py --model_path log3/epoch_60.ckpt --dump_dir log3/dump --output_filelist log3/output_filelist.txt --room_data_filelist meta/area3_data_label.txt +python batch_inference.py --model_path log4/epoch_60.ckpt --dump_dir log4/dump --output_filelist log4/output_filelist.txt --room_data_filelist meta/area4_data_label.txt +python batch_inference.py --model_path log5/epoch_60.ckpt --dump_dir log5/dump --output_filelist log5/output_filelist.txt --room_data_filelist meta/area5_data_label.txt +python batch_inference.py --model_path log6/epoch_60.ckpt --dump_dir log6/dump --output_filelist log6/output_filelist.txt --room_data_filelist meta/area6_data_label.txt \ No newline at end of file diff --git a/sem_seg/train.py b/sem_seg/train.py new file mode 100644 index 0000000..bdec182 --- /dev/null +++ b/sem_seg/train.py @@ -0,0 +1,286 @@ +import argparse +import math +import h5py +import numpy as np +import tensorflow as tf +import socket + +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +ROOT_DIR = os.path.dirname(BASE_DIR) +sys.path.append(BASE_DIR) +sys.path.append(ROOT_DIR) +sys.path.append(os.path.join(ROOT_DIR, 'utils')) +import provider +import tf_util +from model import * + +parser = 
argparse.ArgumentParser() +parser.add_argument('--num_gpu', type=int, default=2, help='the number of GPUs to use [default: 2]') +parser.add_argument('--log_dir', default='log', help='Log dir [default: log]') +parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]') +parser.add_argument('--max_epoch', type=int, default=101, help='Epoch to run [default: 50]') +parser.add_argument('--batch_size', type=int, default=12, help='Batch Size during training for each GPU [default: 24]') +parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]') +parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]') +parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]') +parser.add_argument('--decay_step', type=int, default=300000, help='Decay step for lr decay [default: 300000]') +parser.add_argument('--decay_rate', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]') +parser.add_argument('--test_area', type=int, default=6, help='Which area to use for test, option: 1-6 [default: 6]') +FLAGS = parser.parse_args() + +TOWER_NAME = 'tower' + +BATCH_SIZE = FLAGS.batch_size +NUM_POINT = FLAGS.num_point +MAX_EPOCH = FLAGS.max_epoch +NUM_POINT = FLAGS.num_point +BASE_LEARNING_RATE = FLAGS.learning_rate +MOMENTUM = FLAGS.momentum +OPTIMIZER = FLAGS.optimizer +DECAY_STEP = FLAGS.decay_step +DECAY_RATE = FLAGS.decay_rate + +LOG_DIR = FLAGS.log_dir +if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR) +os.system('cp model.py %s' % (LOG_DIR)) +os.system('cp train.py %s' % (LOG_DIR)) +LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w') +LOG_FOUT.write(str(FLAGS)+'\n') + +MAX_NUM_POINT = 4096 +NUM_CLASSES = 13 + +BN_INIT_DECAY = 0.5 +BN_DECAY_DECAY_RATE = 0.5 +BN_DECAY_DECAY_STEP = float(DECAY_STEP) +BN_DECAY_CLIP = 0.99 + +HOSTNAME = socket.gethostname() + +ALL_FILES = 
provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt')
+room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')]
+print(len(room_filelist))  # print is a function in Python 3
+
+# Load ALL data
+data_batch_list = []
+label_batch_list = []
+for h5_filename in ALL_FILES:
+    data_batch, label_batch = provider.loadDataFile(h5_filename)
+    data_batch_list.append(data_batch)
+    label_batch_list.append(label_batch)
+data_batches = np.concatenate(data_batch_list, 0)
+label_batches = np.concatenate(label_batch_list, 0)
+print(data_batches.shape)
+print(label_batches.shape)
+
+# Hold out one area (1-6) for testing; all other areas are used for training.
+test_area = 'Area_'+str(FLAGS.test_area)
+train_idxs = []
+test_idxs = []
+for i,room_name in enumerate(room_filelist):
+    if test_area in room_name:
+        test_idxs.append(i)
+    else:
+        train_idxs.append(i)
+
+train_data = data_batches[train_idxs,...]
+train_label = label_batches[train_idxs]
+test_data = data_batches[test_idxs,...]
+test_label = label_batches[test_idxs]
+print(train_data.shape, train_label.shape)
+print(test_data.shape, test_label.shape)
+
+
+def log_string(out_str):
+    LOG_FOUT.write(out_str+'\n')
+    LOG_FOUT.flush()
+    print(out_str)
+
+
+def get_learning_rate(batch):
+    learning_rate = tf.train.exponential_decay(
+                        BASE_LEARNING_RATE,  # Base learning rate.
+                        batch * BATCH_SIZE,  # Current index into the dataset.
+                        DECAY_STEP,          # Decay step.
+                        DECAY_RATE,          # Decay rate.
+                        staircase=True)
+    learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!!
+    return learning_rate
+
+def get_bn_decay(batch):
+    bn_momentum = tf.train.exponential_decay(
+                      BN_INIT_DECAY,
+                      batch*BATCH_SIZE,
+                      BN_DECAY_DECAY_STEP,
+                      BN_DECAY_DECAY_RATE,
+                      staircase=True)
+    bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
+    return bn_decay
+
+def average_gradients(tower_grads):
+    """Calculate average gradient for each shared variable across all towers.
+
+    Note that this function provides a synchronization point across all towers.
+
+    Args:
+        tower_grads: List of lists of (gradient, variable) tuples. The outer list
+            is over individual gradients. The inner list is over the gradient
+            calculation for each tower.
+    Returns:
+        List of pairs of (gradient, variable) where the gradient has been
+        averaged across all towers.
+    """
+    average_grads = []
+    for grad_and_vars in zip(*tower_grads):
+        # Note that each grad_and_vars looks like the following:
+        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+        grads = []
+        for g, _ in grad_and_vars:
+            expanded_g = tf.expand_dims(g, 0)
+            grads.append(expanded_g)
+
+        # Average over the 'tower' dimension.
+        grad = tf.concat(grads, 0)
+        grad = tf.reduce_mean(grad, 0)
+
+        # Keep in mind that the Variables are redundant because they are shared
+        # across towers. So .. we will just return the first tower's pointer to
+        # the Variable.
+        v = grad_and_vars[0][1]
+        grad_and_var = (grad, v)
+        average_grads.append(grad_and_var)
+    return average_grads
+
+def train():
+    with tf.Graph().as_default(), tf.device('/cpu:0'):
+        batch = tf.Variable(0, trainable=False)
+
+        bn_decay = get_bn_decay(batch)
+        tf.summary.scalar('bn_decay', bn_decay)
+
+        learning_rate = get_learning_rate(batch)
+        tf.summary.scalar('learning_rate', learning_rate)
+
+        trainer = tf.train.AdamOptimizer(learning_rate)
+
+        tower_grads = []
+        pointclouds_phs = []
+        labels_phs = []
+        is_training_phs =[]
+
+        with tf.variable_scope(tf.get_variable_scope()):
+            for i in range(FLAGS.num_gpu):  # xrange does not exist in Python 3
+                with tf.device('/gpu:%d' % i):
+                    with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
+
+                        pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
+                        is_training_pl = tf.placeholder(tf.bool, shape=())
+
+                        pointclouds_phs.append(pointclouds_pl)
+                        labels_phs.append(labels_pl)
+                        is_training_phs.append(is_training_pl)
+
+                        pred = get_model(pointclouds_phs[-1], is_training_phs[-1], bn_decay=bn_decay)
+                        loss = get_loss(pred, labels_phs[-1])
+                        tf.summary.scalar('loss', loss)
+
+                        correct = 
tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_phs[-1])) + accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT) + tf.summary.scalar('accuracy', accuracy) + + tf.get_variable_scope().reuse_variables() + + grads = trainer.compute_gradients(loss) + + tower_grads.append(grads) + + grads = average_gradients(tower_grads) + + train_op = trainer.apply_gradients(grads, global_step=batch) + + saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10) + + # Create a session + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + sess = tf.Session(config=config) + + # Add summary writers + merged = tf.summary.merge_all() + train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), + sess.graph) + test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test')) + + # Init variables for two GPUs + init = tf.group(tf.global_variables_initializer(), + tf.local_variables_initializer()) + sess.run(init) + + ops = {'pointclouds_phs': pointclouds_phs, + 'labels_phs': labels_phs, + 'is_training_phs': is_training_phs, + 'pred': pred, + 'loss': loss, + 'train_op': train_op, + 'merged': merged, + 'step': batch} + + for epoch in range(MAX_EPOCH): + log_string('**** EPOCH %03d ****' % (epoch)) + sys.stdout.flush() + + train_one_epoch(sess, ops, train_writer) + + # Save the variables to disk. 
+            if epoch % 10 == 0:
+                save_path = saver.save(sess, os.path.join(LOG_DIR,'epoch_' + str(epoch)+'.ckpt'))
+                log_string("Model saved in file: %s" % save_path)
+
+
+
+def train_one_epoch(sess, ops, train_writer):
+    """ ops: dict mapping from string to tf ops """
+    is_training = True
+
+    log_string('----')
+    current_data, current_label, _ = provider.shuffle_data(train_data[:,0:NUM_POINT,:], train_label)
+
+    file_size = current_data.shape[0]
+    num_batches = file_size // (FLAGS.num_gpu * BATCH_SIZE)
+
+    total_correct = 0
+    total_seen = 0
+    loss_sum = 0
+
+    for batch_idx in range(num_batches):
+        if batch_idx % 100 == 0:
+            print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches))
+        # Each step consumes 2*BATCH_SIZE disjoint samples (one slice per GPU
+        # tower), matching num_batches above. A per-step stride of only
+        # BATCH_SIZE would feed tower 1 of step b the same slice as tower 0 of
+        # step b+1 and leave the tail of the shuffled data unused every epoch.
+        start_idx_0 = (batch_idx * 2) * BATCH_SIZE
+        end_idx_0 = start_idx_0 + BATCH_SIZE
+        start_idx_1 = end_idx_0
+        end_idx_1 = start_idx_1 + BATCH_SIZE
+
+
+        feed_dict = {ops['pointclouds_phs'][0]: current_data[start_idx_0:end_idx_0, :, :],
+                     ops['pointclouds_phs'][1]: current_data[start_idx_1:end_idx_1, :, :],
+                     ops['labels_phs'][0]: current_label[start_idx_0:end_idx_0],
+                     ops['labels_phs'][1]: current_label[start_idx_1:end_idx_1],
+                     ops['is_training_phs'][0]: is_training,
+                     ops['is_training_phs'][1]: is_training}
+        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
+                                                        feed_dict=feed_dict)
+        train_writer.add_summary(summary, step)
+        pred_val = np.argmax(pred_val, 2)
+        # pred_val comes from the last tower, so accuracy is tracked on the
+        # second GPU's slice only.
+        correct = np.sum(pred_val == current_label[start_idx_1:end_idx_1])
+        total_correct += correct
+        total_seen += (BATCH_SIZE*NUM_POINT)
+        loss_sum += loss_val
+
+    log_string('mean loss: %f' % (loss_sum / float(num_batches)))
+    log_string('accuracy: %f' % (total_correct / float(total_seen)))
+
+if __name__ == "__main__":
+    train()
+    LOG_FOUT.close()
diff --git a/sem_seg/train_job.sh b/sem_seg/train_job.sh
new file mode 100644
index 0000000..26c23bf
--- /dev/null
+++ b/sem_seg/train_job.sh
@@ -0,0 +1,6 @@
+python train.py --log_dir log1
--test_area 1 +python train.py --log_dir log2 --test_area 2 +python train.py --log_dir log3 --test_area 3 +python train.py --log_dir log4 --test_area 4 +python train.py --log_dir log5 --test_area 5 +python train.py --log_dir log6 --test_area 6 \ No newline at end of file