From 2fee5f8abcc95ed154d1fad6db8a0ff148fbe6d2 Mon Sep 17 00:00:00 2001
From: Yue Wang
Date: Sat, 3 Feb 2018 01:48:00 -0500
Subject: [PATCH] add edge conv

---
 models/dgcnn.py          | 150 +++++++
 models/transform_nets.py |  55 +++
 provider.py              | 149 +++++++
 train.py                 | 265 +++++++++++
 utils/data_prep_util.py  | 145 +++++++
 utils/eulerangles.py     | 418 ++++++++++++++++++
 utils/pc_util.py         | 198 +++++++++
 utils/plyfile.py         | 916 +++++++++++++++++++++++++++++++++++++++
 utils/tf_util.py         | 639 +++++++++++++++++++++++++++
 9 files changed, 2935 insertions(+)
 create mode 100644 models/dgcnn.py
 create mode 100644 models/transform_nets.py
 create mode 100644 provider.py
 create mode 100644 train.py
 create mode 100644 utils/data_prep_util.py
 create mode 100644 utils/eulerangles.py
 create mode 100644 utils/pc_util.py
 create mode 100644 utils/plyfile.py
 create mode 100644 utils/tf_util.py

diff --git a/models/dgcnn.py b/models/dgcnn.py
new file mode 100644
index 0000000..5b768ea
--- /dev/null
+++ b/models/dgcnn.py
@@ -0,0 +1,150 @@
+import tensorflow as tf
+import numpy as np
+import math
+import sys
+import os
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(BASE_DIR)
+sys.path.append(os.path.join(BASE_DIR, '../utils'))
+sys.path.append(os.path.join(BASE_DIR, '../../utils'))
+import tf_util
+from transform_nets import input_transform_net
+
+
+def placeholder_inputs(batch_size, num_point):
+  pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
+  labels_pl = tf.placeholder(tf.int32, shape=(batch_size))
+  return pointclouds_pl, labels_pl
+
+
+def get_model(point_cloud, is_training, bn_decay=None):
+  """ Classification DGCNN, input is BxNx3, output Bx40 """
+  batch_size = point_cloud.get_shape()[0].value
+  num_point = point_cloud.get_shape()[1].value
+  end_points = {}
+  k = 20
+
+  adj_matrix = tf_util.pairwise_distance(point_cloud)
+  nn_idx = tf_util.knn(adj_matrix, k=k)
+  edge_feature = tf_util.get_edge_feature(point_cloud, nn_idx=nn_idx, k=k)
+
+  with tf.variable_scope('transform_net1') as sc:
+    transform = input_transform_net(edge_feature, is_training, bn_decay, K=3)
+
+  point_cloud_transformed = tf.matmul(point_cloud, transform)
+  adj_matrix = tf_util.pairwise_distance(point_cloud_transformed)
+  nn_idx = tf_util.knn(adj_matrix, k=k)
+  edge_feature = tf_util.get_edge_feature(point_cloud_transformed, nn_idx=nn_idx, k=k)
+
+  net = tf_util.conv2d(edge_feature, 64, [1,1],
+                       padding='VALID', stride=[1,1],
+                       bn=True, is_training=is_training,
+                       scope='dgcnn1', bn_decay=bn_decay)
+  net = tf.reduce_max(net, axis=-2, keep_dims=True)
+  net1 = net
+
+  adj_matrix = tf_util.pairwise_distance(net)
+  nn_idx = tf_util.knn(adj_matrix, k=k)
+  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)
+
+  net = tf_util.conv2d(edge_feature, 64, [1,1],
+                       padding='VALID', stride=[1,1],
+                       bn=True, is_training=is_training,
+                       scope='dgcnn2', bn_decay=bn_decay)
+  net = tf.reduce_max(net, axis=-2, keep_dims=True)
+  net2 = net
+
+  adj_matrix = tf_util.pairwise_distance(net)
+  nn_idx = tf_util.knn(adj_matrix, k=k)
+  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)
+
+  net = tf_util.conv2d(edge_feature, 64, [1,1],
+                       padding='VALID', stride=[1,1],
+                       bn=True, is_training=is_training,
+                       scope='dgcnn3', bn_decay=bn_decay)
+  net = tf.reduce_max(net, axis=-2, keep_dims=True)
+  net3 = net
+
+  adj_matrix = tf_util.pairwise_distance(net)
+  nn_idx = tf_util.knn(adj_matrix, k=k)
+  edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)
+
+  net = tf_util.conv2d(edge_feature, 128, [1,1],
+                       padding='VALID', stride=[1,1],
+                       bn=True, is_training=is_training,
+                       scope='dgcnn4', bn_decay=bn_decay)
+  net = tf.reduce_max(net, axis=-2, keep_dims=True)
+  net4 = net
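+  # The kNN graph feeding each EdgeConv block above is recomputed from
+  # pairwise distances in the *current* feature space, not the input
+  # coordinates; this per-layer recomputation is what makes the graph
+  # "dynamic" in DGCNN. The per-layer outputs net1..net4 (64+64+64+128
+  # channels) are concatenated below into a 320-dim per-point feature
+  # before the 1024-dim aggregation layer and global max pool.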
+
+  net = tf_util.conv2d(tf.concat([net1, net2, net3, net4], axis=-1), 1024, [1, 1],
+                       padding='VALID', stride=[1,1],
+                       bn=True, is_training=is_training,
+                       scope='agg', bn_decay=bn_decay)
+
+  net = tf.reduce_max(net, axis=1, keep_dims=True)
+
+  # MLP on global point cloud vector
+  net = tf.reshape(net, [batch_size, -1])
+  net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
+                                scope='fc1', bn_decay=bn_decay)
+  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
+                        scope='dp1')
+  net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
+                                scope='fc2', bn_decay=bn_decay)
+  net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
+                        scope='dp2')
+  net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')
+
+  return net, end_points
+
+
+def get_loss(pred, label, end_points):
+  """ pred: B*NUM_CLASSES,
+      label: B, """
+  labels = tf.one_hot(indices=label, depth=40)
+  loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=pred, label_smoothing=0.2)
+  classify_loss = tf.reduce_mean(loss)
+  return classify_loss
+
+
+if __name__=='__main__':
+  batch_size = 2
+  num_pt = 124
+  pos_dim = 3
+
+  input_feed = np.random.rand(batch_size, num_pt, pos_dim)
+  label_feed = np.random.rand(batch_size)
+  label_feed[label_feed>=0.5] = 1
+  label_feed[label_feed<0.5] = 0
+  label_feed = label_feed.astype(np.int32)
+
+  # # np.save('./debug/input_feed.npy', input_feed)
+  # input_feed = np.load('./debug/input_feed.npy')
+  # print input_feed
+
+  with tf.Graph().as_default():
+    input_pl, label_pl = placeholder_inputs(batch_size, num_pt)
+    pos, ftr = get_model(input_pl, tf.constant(True))
+    # loss = get_loss(logits, label_pl, None)
+
+    with tf.Session() as sess:
+      sess.run(tf.global_variables_initializer())
+      feed_dict = {input_pl: input_feed, label_pl: label_feed}
+      res1, res2 = sess.run([pos, ftr], feed_dict=feed_dict)
+      print res1.shape
+      print res1
+
+      print res2.shape
+      print res2
diff --git a/models/transform_nets.py b/models/transform_nets.py
new file mode 100644
index 0000000..23faf6d
--- /dev/null
+++ b/models/transform_nets.py
@@ -0,0 +1,55 @@
+import tensorflow as tf
+import numpy as np
+import sys
+import os
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(BASE_DIR)
+sys.path.append(os.path.join(BASE_DIR, '../utils'))
+import tf_util
+
+def input_transform_net(edge_feature, is_training, bn_decay=None, K=3):
+    """ Input (XYZ) Transform Net, input is BxNx3 gray image
+        Return:
+            Transformation matrix of size 3xK """
+    batch_size = edge_feature.get_shape()[0].value
+    num_point = edge_feature.get_shape()[1].value
+
+    # input_image = tf.expand_dims(point_cloud, -1)
+    net = tf_util.conv2d(edge_feature, 64, [1,1],
+                         padding='VALID', stride=[1,1],
+                         bn=True, is_training=is_training,
+                         scope='tconv1', bn_decay=bn_decay)
+    net = tf_util.conv2d(net, 128, [1,1],
+                         padding='VALID', stride=[1,1],
+                         bn=True, is_training=is_training,
+                         scope='tconv2', bn_decay=bn_decay)
+
+    net = tf.reduce_max(net, axis=-2, keep_dims=True)
+
+    net = tf_util.conv2d(net, 1024, [1,1],
+                         padding='VALID', stride=[1,1],
+                         bn=True, is_training=is_training,
+                         scope='tconv3', bn_decay=bn_decay)
+    net = tf_util.max_pool2d(net, [num_point,1],
+                             padding='VALID', scope='tmaxpool')
+
+    net = tf.reshape(net, [batch_size, -1])
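+    # At this point the neighbor dimension and the point dimension have both
+    # been max-pooled away, leaving one 1024-dim global descriptor per cloud.
+    # The two fully connected layers below regress a K*K matrix from it; the
+    # regression weights start at zero with an identity bias, so training
+    # begins from the identity transform.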
+    net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
+                                  scope='tfc1', bn_decay=bn_decay)
+    net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
+                                  scope='tfc2', bn_decay=bn_decay)
+
+    with tf.variable_scope('transform_XYZ') as sc:
+        # assert(K==3)
+        weights = tf.get_variable('weights', [256, K*K],
+                                  initializer=tf.constant_initializer(0.0),
+                                  dtype=tf.float32)
+        biases = tf.get_variable('biases', [K*K],
+                                 initializer=tf.constant_initializer(0.0),
+                                 dtype=tf.float32)
+        biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
+        transform = tf.matmul(net, weights)
+        transform = tf.nn.bias_add(transform, biases)
+
+    transform = tf.reshape(transform, [batch_size, K, K])
+    return transform
\ No newline at end of file
diff --git a/provider.py b/provider.py
new file mode 100644
index 0000000..d148e5b
--- /dev/null
+++ b/provider.py
@@ -0,0 +1,149 @@
+import os
+import sys
+import numpy as np
+import h5py
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(BASE_DIR)
+
+# Download dataset for point cloud classification
+DATA_DIR = os.path.join(BASE_DIR, 'data')
+if not os.path.exists(DATA_DIR):
+    os.mkdir(DATA_DIR)
+if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
+    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
+    zipfile = os.path.basename(www)
+    os.system('wget %s; unzip %s' % (www, zipfile))
+    os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
+    os.system('rm %s' % (zipfile))
+
+
+def shuffle_data(data, labels):
+    """ Shuffle data and labels.
+        Input:
+            data: B,N,... numpy array
+            label: B,... numpy array
+        Return:
+            shuffled data, label and shuffle indices
+    """
+    idx = np.arange(len(labels))
+    np.random.shuffle(idx)
+    return data[idx, ...], labels[idx], idx
+
+
+def rotate_point_cloud(batch_data):
+    """ Randomly rotate the point clouds to augment the dataset
+        rotation is per shape based along up direction
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, rotated batch of point clouds
+    """
+    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+    for k in xrange(batch_data.shape[0]):
+        rotation_angle = np.random.uniform() * 2 * np.pi
+        cosval = np.cos(rotation_angle)
+        sinval = np.sin(rotation_angle)
+        rotation_matrix = np.array([[cosval, 0, sinval],
+                                    [0, 1, 0],
+                                    [-sinval, 0, cosval]])
+        shape_pc = batch_data[k, ...]
+        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+    return rotated_data
+
+
+def rotate_point_cloud_by_angle(batch_data, rotation_angle):
+    """ Rotate the point cloud along up direction with certain angle.
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, rotated batch of point clouds
+    """
+    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+    for k in xrange(batch_data.shape[0]):
+        #rotation_angle = np.random.uniform() * 2 * np.pi
+        cosval = np.cos(rotation_angle)
+        sinval = np.sin(rotation_angle)
+        rotation_matrix = np.array([[cosval, 0, sinval],
+                                    [0, 1, 0],
+                                    [-sinval, 0, cosval]])
+        shape_pc = batch_data[k, ...]
+        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
+    return rotated_data
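
Both rotation helpers treat points as row vectors, so a cloud is rotated as shape_pc.dot(R) rather than R.dot(shape_pc.T). A minimal standalone check of that convention (not part of the patch, plain NumPy):

    import numpy as np

    # Rotate a point on the +x axis by 90 degrees about the up (y) axis,
    # using the same matrix layout and row-vector convention as above.
    c, s = np.cos(np.pi/2), np.sin(np.pi/2)
    R = np.array([[c, 0, s],
                  [0, 1, 0],
                  [-s, 0, c]])
    p = np.array([[1.0, 0.0, 0.0]])
    print(p.dot(R))   # -> approximately [[0., 0., 1.]]: +x maps to +z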
+
+
+def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
+    """ Randomly perturb the point clouds by small rotations
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, rotated batch of point clouds
+    """
+    rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
+    for k in xrange(batch_data.shape[0]):
+        angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
+        Rx = np.array([[1,0,0],
+                       [0,np.cos(angles[0]),-np.sin(angles[0])],
+                       [0,np.sin(angles[0]),np.cos(angles[0])]])
+        Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
+                       [0,1,0],
+                       [-np.sin(angles[1]),0,np.cos(angles[1])]])
+        Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
+                       [np.sin(angles[2]),np.cos(angles[2]),0],
+                       [0,0,1]])
+        R = np.dot(Rz, np.dot(Ry,Rx))
+        shape_pc = batch_data[k, ...]
+        rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
+    return rotated_data
+
+
+def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
+    """ Randomly jitter points. jittering is per point.
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, jittered batch of point clouds
+    """
+    B, N, C = batch_data.shape
+    assert(clip > 0)
+    jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
+    jittered_data += batch_data
+    return jittered_data
+
+def shift_point_cloud(batch_data, shift_range=0.1):
+    """ Randomly shift point cloud. Shift is per point cloud.
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, shifted batch of point clouds
+    """
+    B, N, C = batch_data.shape
+    shifts = np.random.uniform(-shift_range, shift_range, (B,3))
+    for batch_index in range(B):
+        batch_data[batch_index,:,:] += shifts[batch_index,:]
+    return batch_data
+
+
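One subtlety: unlike the rotation and jitter helpers above, which allocate fresh output arrays, shift_point_cloud and random_scale_point_cloud (below) modify batch_data in place and return the same array. train.py only ever chains them on already-augmented copies, so this is harmless there, but callers that want to keep the original should pass in a copy. A small demonstration, assuming provider.py is importable:

    import numpy as np
    import provider

    batch = np.zeros((2, 4, 3), dtype=np.float32)
    shifted = provider.shift_point_cloud(batch, shift_range=0.1)
    print(shifted is batch)          # True: same array object
    print(np.abs(batch).max() > 0)   # True: the "original" was shifted too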
+def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
+    """ Randomly scale the point cloud. Scale is per point cloud.
+        Input:
+            BxNx3 array, original batch of point clouds
+        Return:
+            BxNx3 array, scaled batch of point clouds
+    """
+    B, N, C = batch_data.shape
+    scales = np.random.uniform(scale_low, scale_high, B)
+    for batch_index in range(B):
+        batch_data[batch_index,:,:] *= scales[batch_index]
+    return batch_data
+
+def getDataFiles(list_filename):
+    return [line.rstrip() for line in open(list_filename)]
+
+def load_h5(h5_filename):
+    f = h5py.File(h5_filename)
+    data = f['data'][:]
+    label = f['label'][:]
+    return (data, label)
+
+def loadDataFile(filename):
+    return load_h5(filename)
\ No newline at end of file
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..62337c3
--- /dev/null
+++ b/train.py
@@ -0,0 +1,265 @@
+import argparse
+import math
+import h5py
+import numpy as np
+import tensorflow as tf
+import socket
+import importlib
+import os
+import sys
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(BASE_DIR)
+sys.path.append(os.path.join(BASE_DIR, 'models'))
+sys.path.append(os.path.join(BASE_DIR, 'utils'))
+import provider
+import tf_util
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
+parser.add_argument('--model', default='dgcnn', help='Model name: dgcnn')
+parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
+parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
+parser.add_argument('--max_epoch', type=int, default=250, help='Epoch to run [default: 250]')
+parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
+parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
+parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for momentum optimizer [default: 0.9]')
+parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
+parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
+parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
+FLAGS = parser.parse_args()
+
+
+BATCH_SIZE = FLAGS.batch_size
+NUM_POINT = FLAGS.num_point
+MAX_EPOCH = FLAGS.max_epoch
+BASE_LEARNING_RATE = FLAGS.learning_rate
+GPU_INDEX = FLAGS.gpu
+MOMENTUM = FLAGS.momentum
+OPTIMIZER = FLAGS.optimizer
+DECAY_STEP = FLAGS.decay_step
+DECAY_RATE = FLAGS.decay_rate
+
+MODEL = importlib.import_module(FLAGS.model) # import network module
+MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py')
+LOG_DIR = FLAGS.log_dir
+if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
+os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
+os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure
+LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
+LOG_FOUT.write(str(FLAGS)+'\n')
+
+MAX_NUM_POINT = 2048
+NUM_CLASSES = 40
+
+BN_INIT_DECAY = 0.5
+BN_DECAY_DECAY_RATE = 0.5
+BN_DECAY_DECAY_STEP = float(DECAY_STEP)
+BN_DECAY_CLIP = 0.99
+
+HOSTNAME = socket.gethostname()
+
+# ModelNet40 official train/test split
+TRAIN_FILES = provider.getDataFiles( \
+    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
+TEST_FILES = provider.getDataFiles(\
+    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))
+
+def log_string(out_str):
+    LOG_FOUT.write(out_str+'\n')
+    LOG_FOUT.flush()
+    print(out_str)
+
+
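get_learning_rate below wires tf.train.exponential_decay with staircase=True, keyed on the number of samples seen (batch * BATCH_SIZE), and clips the result from below at 1e-5; get_bn_decay ramps the batch-norm momentum the same way toward its 0.99 cap. A pure-Python mirror of the learning-rate schedule under the default flags, to make the decay boundaries concrete (illustrative only, not part of the patch):

    def expected_lr(step, base=0.001, rate=0.7, decay_step=200000, batch_size=32):
        # staircase: the rate is applied once per full decay_step of samples
        decayed = base * rate ** ((step * batch_size) // decay_step)
        return max(decayed, 0.00001)  # same floor as the tf.maximum clip below

    print(expected_lr(0))       # 0.001
    print(expected_lr(6250))    # 6250 * 32 = 200000 samples -> 0.0007
    print(expected_lr(12500))   # 0.00049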
+def get_learning_rate(batch):
+    learning_rate = tf.train.exponential_decay(
+                        BASE_LEARNING_RATE,  # Base learning rate.
+                        batch * BATCH_SIZE,  # Current index into the dataset.
+                        DECAY_STEP,          # Decay step.
+                        DECAY_RATE,          # Decay rate.
+                        staircase=True)
+    learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
+    return learning_rate
+
+def get_bn_decay(batch):
+    bn_momentum = tf.train.exponential_decay(
+                      BN_INIT_DECAY,
+                      batch*BATCH_SIZE,
+                      BN_DECAY_DECAY_STEP,
+                      BN_DECAY_DECAY_RATE,
+                      staircase=True)
+    bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
+    return bn_decay
+
+def train():
+    with tf.Graph().as_default():
+        with tf.device('/gpu:'+str(GPU_INDEX)):
+            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
+            is_training_pl = tf.placeholder(tf.bool, shape=())
+            print(is_training_pl)
+
+            # Note the global_step=batch parameter to minimize.
+            # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
+            batch = tf.Variable(0)
+            bn_decay = get_bn_decay(batch)
+            tf.summary.scalar('bn_decay', bn_decay)
+
+            # Get model and loss
+            pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
+            loss = MODEL.get_loss(pred, labels_pl, end_points)
+            tf.summary.scalar('loss', loss)
+
+            correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
+            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
+            tf.summary.scalar('accuracy', accuracy)
+
+            # Get training operator
+            learning_rate = get_learning_rate(batch)
+            tf.summary.scalar('learning_rate', learning_rate)
+            if OPTIMIZER == 'momentum':
+                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
+            elif OPTIMIZER == 'adam':
+                optimizer = tf.train.AdamOptimizer(learning_rate)
+            train_op = optimizer.minimize(loss, global_step=batch)
+
+            # Add ops to save and restore all the variables.
+            saver = tf.train.Saver()
+
+        # Create a session
+        config = tf.ConfigProto()
+        config.gpu_options.allow_growth = True
+        config.allow_soft_placement = True
+        config.log_device_placement = False
+        sess = tf.Session(config=config)
+
+        # Add summary writers
+        #merged = tf.merge_all_summaries()
+        merged = tf.summary.merge_all()
+        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
+                                             sess.graph)
+        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
+
+        # Init variables
+        init = tf.global_variables_initializer()
+        # To fix the bug introduced in TF 0.12.1 as in
+        # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
+        #sess.run(init)
+        sess.run(init, {is_training_pl: True})
+
+        ops = {'pointclouds_pl': pointclouds_pl,
+               'labels_pl': labels_pl,
+               'is_training_pl': is_training_pl,
+               'pred': pred,
+               'loss': loss,
+               'train_op': train_op,
+               'merged': merged,
+               'step': batch}
+
+        for epoch in range(MAX_EPOCH):
+            log_string('**** EPOCH %03d ****' % (epoch))
+            sys.stdout.flush()
+
+            train_one_epoch(sess, ops, train_writer)
+            eval_one_epoch(sess, ops, test_writer)
+
+            # Save the variables to disk.
+ if epoch % 10 == 0: + save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt")) + log_string("Model saved in file: %s" % save_path) + + + +def train_one_epoch(sess, ops, train_writer): + """ ops: dict mapping from string to tf ops """ + is_training = True + + # Shuffle train files + train_file_idxs = np.arange(0, len(TRAIN_FILES)) + np.random.shuffle(train_file_idxs) + + for fn in range(len(TRAIN_FILES)): + log_string('----' + str(fn) + '-----') + current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]]) + current_data = current_data[:,0:NUM_POINT,:] + current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label)) + current_label = np.squeeze(current_label) + + file_size = current_data.shape[0] + num_batches = file_size // BATCH_SIZE + + total_correct = 0 + total_seen = 0 + loss_sum = 0 + + for batch_idx in range(num_batches): + start_idx = batch_idx * BATCH_SIZE + end_idx = (batch_idx+1) * BATCH_SIZE + + # Augment batched point clouds by rotation and jittering + rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :]) + jittered_data = provider.jitter_point_cloud(rotated_data) + jittered_data = provider.random_scale_point_cloud(jittered_data) + jittered_data = provider.rotate_perturbation_point_cloud(jittered_data) + jittered_data = provider.shift_point_cloud(jittered_data) + + feed_dict = {ops['pointclouds_pl']: jittered_data, + ops['labels_pl']: current_label[start_idx:end_idx], + ops['is_training_pl']: is_training,} + summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], + ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict) + train_writer.add_summary(summary, step) + pred_val = np.argmax(pred_val, 1) + correct = np.sum(pred_val == current_label[start_idx:end_idx]) + total_correct += correct + total_seen += BATCH_SIZE + loss_sum += loss_val + + log_string('mean loss: %f' % (loss_sum / float(num_batches))) + log_string('accuracy: %f' % (total_correct / float(total_seen))) + + +def eval_one_epoch(sess, ops, test_writer): + """ ops: dict mapping from string to tf ops """ + is_training = False + total_correct = 0 + total_seen = 0 + loss_sum = 0 + total_seen_class = [0 for _ in range(NUM_CLASSES)] + total_correct_class = [0 for _ in range(NUM_CLASSES)] + + for fn in range(len(TEST_FILES)): + log_string('----' + str(fn) + '-----') + current_data, current_label = provider.loadDataFile(TEST_FILES[fn]) + current_data = current_data[:,0:NUM_POINT,:] + current_label = np.squeeze(current_label) + + file_size = current_data.shape[0] + num_batches = file_size // BATCH_SIZE + + for batch_idx in range(num_batches): + start_idx = batch_idx * BATCH_SIZE + end_idx = (batch_idx+1) * BATCH_SIZE + + feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :], + ops['labels_pl']: current_label[start_idx:end_idx], + ops['is_training_pl']: is_training} + summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'], + ops['loss'], ops['pred']], feed_dict=feed_dict) + pred_val = np.argmax(pred_val, 1) + correct = np.sum(pred_val == current_label[start_idx:end_idx]) + total_correct += correct + total_seen += BATCH_SIZE + loss_sum += (loss_val*BATCH_SIZE) + for i in range(start_idx, end_idx): + l = current_label[i] + total_seen_class[l] += 1 + total_correct_class[l] += (pred_val[i-start_idx] == l) + + log_string('eval mean loss: %f' % (loss_sum / float(total_seen))) + log_string('eval accuracy: %f'% (total_correct / float(total_seen))) + log_string('eval avg 
class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float)))) + + + +if __name__ == "__main__": + train() + LOG_FOUT.close() diff --git a/utils/data_prep_util.py b/utils/data_prep_util.py new file mode 100644 index 0000000..53d32f1 --- /dev/null +++ b/utils/data_prep_util.py @@ -0,0 +1,145 @@ +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(BASE_DIR) +from plyfile import (PlyData, PlyElement, make2d, PlyParseError, PlyProperty) +import numpy as np +import h5py + +SAMPLING_BIN = os.path.join(BASE_DIR, 'third_party/mesh_sampling/build/pcsample') + +SAMPLING_POINT_NUM = 2048 +SAMPLING_LEAF_SIZE = 0.005 + +MODELNET40_PATH = '../datasets/modelnet40' +def export_ply(pc, filename): + vertex = np.zeros(pc.shape[0], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + for i in range(pc.shape[0]): + vertex[i] = (pc[i][0], pc[i][1], pc[i][2]) + ply_out = PlyData([PlyElement.describe(vertex, 'vertex', comments=['vertices'])]) + ply_out.write(filename) + +# Sample points on the obj shape +def get_sampling_command(obj_filename, ply_filename): + cmd = SAMPLING_BIN + ' ' + obj_filename + cmd += ' ' + ply_filename + cmd += ' -n_samples %d ' % SAMPLING_POINT_NUM + cmd += ' -leaf_size %f ' % SAMPLING_LEAF_SIZE + return cmd + +# -------------------------------------------------------------- +# Following are the helper functions to load MODELNET40 shapes +# -------------------------------------------------------------- + +# Read in the list of categories in MODELNET40 +def get_category_names(): + shape_names_file = os.path.join(MODELNET40_PATH, 'shape_names.txt') + shape_names = [line.rstrip() for line in open(shape_names_file)] + return shape_names + +# Return all the filepaths for the shapes in MODELNET40 +def get_obj_filenames(): + obj_filelist_file = os.path.join(MODELNET40_PATH, 'filelist.txt') + obj_filenames = [os.path.join(MODELNET40_PATH, line.rstrip()) for line in open(obj_filelist_file)] + print('Got %d obj files in modelnet40.' 
% len(obj_filenames))
+    return obj_filenames
+
+# Helper function to create the father folder and all subdir folders if not exist
+def batch_mkdir(output_folder, subdir_list):
+    if not os.path.exists(output_folder):
+        os.mkdir(output_folder)
+    for subdir in subdir_list:
+        if not os.path.exists(os.path.join(output_folder, subdir)):
+            os.mkdir(os.path.join(output_folder, subdir))
+
+# ----------------------------------------------------------------
+# Following are the helper functions to save/load HDF5 files
+# ----------------------------------------------------------------
+
+# Write numpy array data, label and normal to h5_filename
+def save_h5_data_label_normal(h5_filename, data, label, normal,
+        data_dtype='float32', label_dtype='uint8', normal_dtype='float32'):
+    h5_fout = h5py.File(h5_filename)
+    h5_fout.create_dataset(
+            'data', data=data,
+            compression='gzip', compression_opts=4,
+            dtype=data_dtype)
+    h5_fout.create_dataset(
+            'normal', data=normal,
+            compression='gzip', compression_opts=4,
+            dtype=normal_dtype)
+    h5_fout.create_dataset(
+            'label', data=label,
+            compression='gzip', compression_opts=1,
+            dtype=label_dtype)
+    h5_fout.close()
+
+
+# Write numpy array data and label to h5_filename
+def save_h5(h5_filename, data, label, data_dtype='uint8', label_dtype='uint8'):
+    h5_fout = h5py.File(h5_filename)
+    h5_fout.create_dataset(
+            'data', data=data,
+            compression='gzip', compression_opts=4,
+            dtype=data_dtype)
+    h5_fout.create_dataset(
+            'label', data=label,
+            compression='gzip', compression_opts=1,
+            dtype=label_dtype)
+    h5_fout.close()
+
+# Read numpy array data and label from h5_filename
+def load_h5_data_label_normal(h5_filename):
+    f = h5py.File(h5_filename)
+    data = f['data'][:]
+    label = f['label'][:]
+    normal = f['normal'][:]
+    return (data, label, normal)
+
+# Read numpy array data and label from h5_filename
+def load_h5_data_label_seg(h5_filename):
+    f = h5py.File(h5_filename)
+    data = f['data'][:]
+    label = f['label'][:]
+    seg = f['pid'][:]
+    return (data, label, seg)
+
+# Read numpy array data and label from h5_filename
+def load_h5(h5_filename):
+    f = h5py.File(h5_filename)
+    data = f['data'][:]
+    label = f['label'][:]
+    return (data, label)
+
+# ----------------------------------------------------------------
+# Following are the helper functions to save/load PLY files
+# ----------------------------------------------------------------
+
+# Load PLY file
+def load_ply_data(filename, point_num):
+    plydata = PlyData.read(filename)
+    pc = plydata['vertex'].data[:point_num]
+    pc_array = np.array([[x, y, z] for x,y,z in pc])
+    return pc_array
+
+# Load PLY file
+def load_ply_normal(filename, point_num):
+    plydata = PlyData.read(filename)
+    pc = plydata['normal'].data[:point_num]
+    pc_array = np.array([[x, y, z] for x,y,z in pc])
+    return pc_array
+
+# Make up rows for Nxk array
+# Input Pad is 'edge' or 'constant'
+def pad_arr_rows(arr, row, pad='edge'):
+    assert(len(arr.shape) == 2)
+    assert(arr.shape[0] <= row)
+    assert(pad == 'edge' or pad == 'constant')
+    if arr.shape[0] == row:
+        return arr
+    if pad == 'edge':
+        return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'edge')
+    if pad == 'constant':
+        return np.lib.pad(arr, ((0, row-arr.shape[0]), (0, 0)), 'constant', constant_values=(0, 0))

diff --git a/utils/eulerangles.py b/utils/eulerangles.py
new file mode 100644
index 0000000..87bd605
--- /dev/null
+++ b/utils/eulerangles.py
@@ -0,0 +1,418 @@
+# emacs: -*- mode: python-mode; py-indent-offset: 4; indent-tabs-mode: nil -*-
+# vi: set ft=python sts=4 ts=4 sw=4 et:
+### ###
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +# +# See COPYING file distributed along with the NiBabel package for the +# copyright and license terms. +# +### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## +''' Module implementing Euler angle rotations and their conversions + +See: + +* http://en.wikipedia.org/wiki/Rotation_matrix +* http://en.wikipedia.org/wiki/Euler_angles +* http://mathworld.wolfram.com/EulerAngles.html + +See also: *Representing Attitude with Euler Angles and Quaternions: A +Reference* (2006) by James Diebel. A cached PDF link last found here: + +http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.110.5134 + +Euler's rotation theorem tells us that any rotation in 3D can be +described by 3 angles. Let's call the 3 angles the *Euler angle vector* +and call the angles in the vector :math:`alpha`, :math:`beta` and +:math:`gamma`. The vector is [ :math:`alpha`, +:math:`beta`. :math:`gamma` ] and, in this description, the order of the +parameters specifies the order in which the rotations occur (so the +rotation corresponding to :math:`alpha` is applied first). + +In order to specify the meaning of an *Euler angle vector* we need to +specify the axes around which each of the rotations corresponding to +:math:`alpha`, :math:`beta` and :math:`gamma` will occur. + +There are therefore three axes for the rotations :math:`alpha`, +:math:`beta` and :math:`gamma`; let's call them :math:`i` :math:`j`, +:math:`k`. + +Let us express the rotation :math:`alpha` around axis `i` as a 3 by 3 +rotation matrix `A`. Similarly :math:`beta` around `j` becomes 3 x 3 +matrix `B` and :math:`gamma` around `k` becomes matrix `G`. Then the +whole rotation expressed by the Euler angle vector [ :math:`alpha`, +:math:`beta`. :math:`gamma` ], `R` is given by:: + + R = np.dot(G, np.dot(B, A)) + +See http://mathworld.wolfram.com/EulerAngles.html + +The order :math:`G B A` expresses the fact that the rotations are +performed in the order of the vector (:math:`alpha` around axis `i` = +`A` first). + +To convert a given Euler angle vector to a meaningful rotation, and a +rotation matrix, we need to define: + +* the axes `i`, `j`, `k` +* whether a rotation matrix should be applied on the left of a vector to + be transformed (vectors are column vectors) or on the right (vectors + are row vectors). +* whether the rotations move the axes as they are applied (intrinsic + rotations) - compared the situation where the axes stay fixed and the + vectors move within the axis frame (extrinsic) +* the handedness of the coordinate system + +See: http://en.wikipedia.org/wiki/Rotation_matrix#Ambiguities + +We are using the following conventions: + +* axes `i`, `j`, `k` are the `z`, `y`, and `x` axes respectively. Thus + an Euler angle vector [ :math:`alpha`, :math:`beta`. :math:`gamma` ] + in our convention implies a :math:`alpha` radian rotation around the + `z` axis, followed by a :math:`beta` rotation around the `y` axis, + followed by a :math:`gamma` rotation around the `x` axis. +* the rotation matrix applies on the left, to column vectors on the + right, so if `R` is the rotation matrix, and `v` is a 3 x N matrix + with N column vectors, the transformed vector set `vdash` is given by + ``vdash = np.dot(R, v)``. +* extrinsic rotations - the axes are fixed, and do not move with the + rotations. 
+* a right-handed coordinate system + +The convention of rotation around ``z``, followed by rotation around +``y``, followed by rotation around ``x``, is known (confusingly) as +"xyz", pitch-roll-yaw, Cardan angles, or Tait-Bryan angles. +''' + +import math + +import sys +if sys.version_info >= (3,0): + from functools import reduce + +import numpy as np + + +_FLOAT_EPS_4 = np.finfo(float).eps * 4.0 + + +def euler2mat(z=0, y=0, x=0): + ''' Return matrix for rotations around z, y and x axes + + Uses the z, then y, then x convention above + + Parameters + ---------- + z : scalar + Rotation angle in radians around z-axis (performed first) + y : scalar + Rotation angle in radians around y-axis + x : scalar + Rotation angle in radians around x-axis (performed last) + + Returns + ------- + M : array shape (3,3) + Rotation matrix giving same rotation as for given angles + + Examples + -------- + >>> zrot = 1.3 # radians + >>> yrot = -0.1 + >>> xrot = 0.2 + >>> M = euler2mat(zrot, yrot, xrot) + >>> M.shape == (3, 3) + True + + The output rotation matrix is equal to the composition of the + individual rotations + + >>> M1 = euler2mat(zrot) + >>> M2 = euler2mat(0, yrot) + >>> M3 = euler2mat(0, 0, xrot) + >>> composed_M = np.dot(M3, np.dot(M2, M1)) + >>> np.allclose(M, composed_M) + True + + You can specify rotations by named arguments + + >>> np.all(M3 == euler2mat(x=xrot)) + True + + When applying M to a vector, the vector should column vector to the + right of M. If the right hand side is a 2D array rather than a + vector, then each column of the 2D array represents a vector. + + >>> vec = np.array([1, 0, 0]).reshape((3,1)) + >>> v2 = np.dot(M, vec) + >>> vecs = np.array([[1, 0, 0],[0, 1, 0]]).T # giving 3x2 array + >>> vecs2 = np.dot(M, vecs) + + Rotations are counter-clockwise. + + >>> zred = np.dot(euler2mat(z=np.pi/2), np.eye(3)) + >>> np.allclose(zred, [[0, -1, 0],[1, 0, 0], [0, 0, 1]]) + True + >>> yred = np.dot(euler2mat(y=np.pi/2), np.eye(3)) + >>> np.allclose(yred, [[0, 0, 1],[0, 1, 0], [-1, 0, 0]]) + True + >>> xred = np.dot(euler2mat(x=np.pi/2), np.eye(3)) + >>> np.allclose(xred, [[1, 0, 0],[0, 0, -1], [0, 1, 0]]) + True + + Notes + ----- + The direction of rotation is given by the right-hand rule (orient + the thumb of the right hand along the axis around which the rotation + occurs, with the end of the thumb at the positive end of the axis; + curl your fingers; the direction your fingers curl is the direction + of rotation). Therefore, the rotations are counterclockwise if + looking along the axis of rotation from positive to negative. + ''' + Ms = [] + if z: + cosz = math.cos(z) + sinz = math.sin(z) + Ms.append(np.array( + [[cosz, -sinz, 0], + [sinz, cosz, 0], + [0, 0, 1]])) + if y: + cosy = math.cos(y) + siny = math.sin(y) + Ms.append(np.array( + [[cosy, 0, siny], + [0, 1, 0], + [-siny, 0, cosy]])) + if x: + cosx = math.cos(x) + sinx = math.sin(x) + Ms.append(np.array( + [[1, 0, 0], + [0, cosx, -sinx], + [0, sinx, cosx]])) + if Ms: + return reduce(np.dot, Ms[::-1]) + return np.eye(3) + + +def mat2euler(M, cy_thresh=None): + ''' Discover Euler angle vector from 3x3 matrix + + Uses the conventions above. + + Parameters + ---------- + M : array-like, shape (3,3) + cy_thresh : None or scalar, optional + threshold below which to give up on straightforward arctan for + estimating x rotation. If None (default), estimate from + precision of input. 
+ + Returns + ------- + z : scalar + y : scalar + x : scalar + Rotations in radians around z, y, x axes, respectively + + Notes + ----- + If there was no numerical error, the routine could be derived using + Sympy expression for z then y then x rotation matrix, which is:: + + [ cos(y)*cos(z), -cos(y)*sin(z), sin(y)], + [cos(x)*sin(z) + cos(z)*sin(x)*sin(y), cos(x)*cos(z) - sin(x)*sin(y)*sin(z), -cos(y)*sin(x)], + [sin(x)*sin(z) - cos(x)*cos(z)*sin(y), cos(z)*sin(x) + cos(x)*sin(y)*sin(z), cos(x)*cos(y)] + + with the obvious derivations for z, y, and x + + z = atan2(-r12, r11) + y = asin(r13) + x = atan2(-r23, r33) + + Problems arise when cos(y) is close to zero, because both of:: + + z = atan2(cos(y)*sin(z), cos(y)*cos(z)) + x = atan2(cos(y)*sin(x), cos(x)*cos(y)) + + will be close to atan2(0, 0), and highly unstable. + + The ``cy`` fix for numerical instability below is from: *Graphics + Gems IV*, Paul Heckbert (editor), Academic Press, 1994, ISBN: + 0123361559. Specifically it comes from EulerAngles.c by Ken + Shoemake, and deals with the case where cos(y) is close to zero: + + See: http://www.graphicsgems.org/ + + The code appears to be licensed (from the website) as "can be used + without restrictions". + ''' + M = np.asarray(M) + if cy_thresh is None: + try: + cy_thresh = np.finfo(M.dtype).eps * 4 + except ValueError: + cy_thresh = _FLOAT_EPS_4 + r11, r12, r13, r21, r22, r23, r31, r32, r33 = M.flat + # cy: sqrt((cos(y)*cos(z))**2 + (cos(x)*cos(y))**2) + cy = math.sqrt(r33*r33 + r23*r23) + if cy > cy_thresh: # cos(y) not close to zero, standard form + z = math.atan2(-r12, r11) # atan2(cos(y)*sin(z), cos(y)*cos(z)) + y = math.atan2(r13, cy) # atan2(sin(y), cy) + x = math.atan2(-r23, r33) # atan2(cos(y)*sin(x), cos(x)*cos(y)) + else: # cos(y) (close to) zero, so x -> 0.0 (see above) + # so r21 -> sin(z), r22 -> cos(z) and + z = math.atan2(r21, r22) + y = math.atan2(r13, cy) # atan2(sin(y), cy) + x = 0.0 + return z, y, x + + +def euler2quat(z=0, y=0, x=0): + ''' Return quaternion corresponding to these Euler angles + + Uses the z, then y, then x convention above + + Parameters + ---------- + z : scalar + Rotation angle in radians around z-axis (performed first) + y : scalar + Rotation angle in radians around y-axis + x : scalar + Rotation angle in radians around x-axis (performed last) + + Returns + ------- + quat : array shape (4,) + Quaternion in w, x, y z (real, then vector) format + + Notes + ----- + We can derive this formula in Sympy using: + + 1. Formula giving quaternion corresponding to rotation of theta radians + about arbitrary axis: + http://mathworld.wolfram.com/EulerParameters.html + 2. Generated formulae from 1.) for quaternions corresponding to + theta radians rotations about ``x, y, z`` axes + 3. Apply quaternion multiplication formula - + http://en.wikipedia.org/wiki/Quaternions#Hamilton_product - to + formulae from 2.) to give formula for combined rotations. 
+ ''' + z = z/2.0 + y = y/2.0 + x = x/2.0 + cz = math.cos(z) + sz = math.sin(z) + cy = math.cos(y) + sy = math.sin(y) + cx = math.cos(x) + sx = math.sin(x) + return np.array([ + cx*cy*cz - sx*sy*sz, + cx*sy*sz + cy*cz*sx, + cx*cz*sy - sx*cy*sz, + cx*cy*sz + sx*cz*sy]) + + +def quat2euler(q): + ''' Return Euler angles corresponding to quaternion `q` + + Parameters + ---------- + q : 4 element sequence + w, x, y, z of quaternion + + Returns + ------- + z : scalar + Rotation angle in radians around z-axis (performed first) + y : scalar + Rotation angle in radians around y-axis + x : scalar + Rotation angle in radians around x-axis (performed last) + + Notes + ----- + It's possible to reduce the amount of calculation a little, by + combining parts of the ``quat2mat`` and ``mat2euler`` functions, but + the reduction in computation is small, and the code repetition is + large. + ''' + # delayed import to avoid cyclic dependencies + import nibabel.quaternions as nq + return mat2euler(nq.quat2mat(q)) + + +def euler2angle_axis(z=0, y=0, x=0): + ''' Return angle, axis corresponding to these Euler angles + + Uses the z, then y, then x convention above + + Parameters + ---------- + z : scalar + Rotation angle in radians around z-axis (performed first) + y : scalar + Rotation angle in radians around y-axis + x : scalar + Rotation angle in radians around x-axis (performed last) + + Returns + ------- + theta : scalar + angle of rotation + vector : array shape (3,) + axis around which rotation occurs + + Examples + -------- + >>> theta, vec = euler2angle_axis(0, 1.5, 0) + >>> print(theta) + 1.5 + >>> np.allclose(vec, [0, 1, 0]) + True + ''' + # delayed import to avoid cyclic dependencies + import nibabel.quaternions as nq + return nq.quat2angle_axis(euler2quat(z, y, x)) + + +def angle_axis2euler(theta, vector, is_normalized=False): + ''' Convert angle, axis pair to Euler angles + + Parameters + ---------- + theta : scalar + angle of rotation + vector : 3 element sequence + vector specifying axis for rotation. + is_normalized : bool, optional + True if vector is already normalized (has norm of 1). Default + False + + Returns + ------- + z : scalar + y : scalar + x : scalar + Rotations in radians around z, y, x axes, respectively + + Examples + -------- + >>> z, y, x = angle_axis2euler(0, [1, 0, 0]) + >>> np.allclose((z, y, x), 0) + True + + Notes + ----- + It's possible to reduce the amount of calculation a little, by + combining parts of the ``angle_axis2mat`` and ``mat2euler`` + functions, but the reduction in computation is small, and the code + repetition is large. + ''' + # delayed import to avoid cyclic dependencies + import nibabel.quaternions as nq + M = nq.angle_axis2mat(theta, vector, is_normalized) + return mat2euler(M) diff --git a/utils/pc_util.py b/utils/pc_util.py new file mode 100644 index 0000000..4913231 --- /dev/null +++ b/utils/pc_util.py @@ -0,0 +1,198 @@ +""" Utility functions for processing point clouds. + +Author: Charles R. 
Qi, Hao Su +Date: November 2016 +""" + +import os +import sys +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(BASE_DIR) + +# Draw point cloud +from eulerangles import euler2mat + +# Point cloud IO +import numpy as np +from plyfile import PlyData, PlyElement + + +# ---------------------------------------- +# Point Cloud/Volume Conversions +# ---------------------------------------- + +def point_cloud_to_volume_batch(point_clouds, vsize=12, radius=1.0, flatten=True): + """ Input is BxNx3 batch of point cloud + Output is Bx(vsize^3) + """ + vol_list = [] + for b in range(point_clouds.shape[0]): + vol = point_cloud_to_volume(np.squeeze(point_clouds[b,:,:]), vsize, radius) + if flatten: + vol_list.append(vol.flatten()) + else: + vol_list.append(np.expand_dims(np.expand_dims(vol, -1), 0)) + if flatten: + return np.vstack(vol_list) + else: + return np.concatenate(vol_list, 0) + + +def point_cloud_to_volume(points, vsize, radius=1.0): + """ input is Nx3 points. + output is vsize*vsize*vsize + assumes points are in range [-radius, radius] + """ + vol = np.zeros((vsize,vsize,vsize)) + voxel = 2*radius/float(vsize) + locations = (points + radius)/voxel + locations = locations.astype(int) + vol[locations[:,0],locations[:,1],locations[:,2]] = 1.0 + return vol + +#a = np.zeros((16,1024,3)) +#print point_cloud_to_volume_batch(a, 12, 1.0, False).shape + +def volume_to_point_cloud(vol): + """ vol is occupancy grid (value = 0 or 1) of size vsize*vsize*vsize + return Nx3 numpy array. + """ + vsize = vol.shape[0] + assert(vol.shape[1] == vsize and vol.shape[1] == vsize) + points = [] + for a in range(vsize): + for b in range(vsize): + for c in range(vsize): + if vol[a,b,c] == 1: + points.append(np.array([a,b,c])) + if len(points) == 0: + return np.zeros((0,3)) + points = np.vstack(points) + return points + +# ---------------------------------------- +# Point cloud IO +# ---------------------------------------- + +def read_ply(filename): + """ read XYZ point cloud from filename PLY file """ + plydata = PlyData.read(filename) + pc = plydata['vertex'].data + pc_array = np.array([[x, y, z] for x,y,z in pc]) + return pc_array + + +def write_ply(points, filename, text=True): + """ input: Nx3, write points to filename as PLY format. """ + points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] + vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) + el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) + PlyData([el], text=text).write(filename) + + +# ---------------------------------------- +# Simple Point cloud and Volume Renderers +# ---------------------------------------- + +def draw_point_cloud(input_points, canvasSize=500, space=200, diameter=25, + xrot=0, yrot=0, zrot=0, switch_xyz=[0,1,2], normalize=True): + """ Render point cloud to image with alpha channel. 
+ Input: + points: Nx3 numpy array (+y is up direction) + Output: + gray image as numpy array of size canvasSizexcanvasSize + """ + image = np.zeros((canvasSize, canvasSize)) + if input_points is None or input_points.shape[0] == 0: + return image + + points = input_points[:, switch_xyz] + M = euler2mat(zrot, yrot, xrot) + points = (np.dot(M, points.transpose())).transpose() + + # Normalize the point cloud + # We normalize scale to fit points in a unit sphere + if normalize: + centroid = np.mean(points, axis=0) + points -= centroid + furthest_distance = np.max(np.sqrt(np.sum(abs(points)**2,axis=-1))) + points /= furthest_distance + + # Pre-compute the Gaussian disk + radius = (diameter-1)/2.0 + disk = np.zeros((diameter, diameter)) + for i in range(diameter): + for j in range(diameter): + if (i - radius) * (i-radius) + (j-radius) * (j-radius) <= radius * radius: + disk[i, j] = np.exp((-(i-radius)**2 - (j-radius)**2)/(radius**2)) + mask = np.argwhere(disk > 0) + dx = mask[:, 0] + dy = mask[:, 1] + dv = disk[disk > 0] + + # Order points by z-buffer + zorder = np.argsort(points[:, 2]) + points = points[zorder, :] + points[:, 2] = (points[:, 2] - np.min(points[:, 2])) / (np.max(points[:, 2] - np.min(points[:, 2]))) + max_depth = np.max(points[:, 2]) + + for i in range(points.shape[0]): + j = points.shape[0] - i - 1 + x = points[j, 0] + y = points[j, 1] + xc = canvasSize/2 + (x*space) + yc = canvasSize/2 + (y*space) + xc = int(np.round(xc)) + yc = int(np.round(yc)) + + px = dx + xc + py = dy + yc + + image[px, py] = image[px, py] * 0.7 + dv * (max_depth - points[j, 2]) * 0.3 + + image = image / np.max(image) + return image + +def point_cloud_three_views(points): + """ input points Nx3 numpy array (+y is up direction). + return an numpy array gray image of size 500x1500. """ + # +y is up direction + # xrot is azimuth + # yrot is in-plane + # zrot is elevation + img1 = draw_point_cloud(points, zrot=110/180.0*np.pi, xrot=45/180.0*np.pi, yrot=0/180.0*np.pi) + img2 = draw_point_cloud(points, zrot=70/180.0*np.pi, xrot=135/180.0*np.pi, yrot=0/180.0*np.pi) + img3 = draw_point_cloud(points, zrot=180.0/180.0*np.pi, xrot=90/180.0*np.pi, yrot=0/180.0*np.pi) + image_large = np.concatenate([img1, img2, img3], 1) + return image_large + + +from PIL import Image +def point_cloud_three_views_demo(): + """ Demo for draw_point_cloud function """ + points = read_ply('../third_party/mesh_sampling/piano.ply') + im_array = point_cloud_three_views(points) + img = Image.fromarray(np.uint8(im_array*255.0)) + img.save('piano.jpg') + +if __name__=="__main__": + point_cloud_three_views_demo() + + +import matplotlib.pyplot as plt +def pyplot_draw_point_cloud(points, output_filename): + """ points is a Nx3 numpy array """ + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(points[:,0], points[:,1], points[:,2]) + ax.set_xlabel('x') + ax.set_ylabel('y') + ax.set_zlabel('z') + #savefig(output_filename) + +def pyplot_draw_volume(vol, output_filename): + """ vol is of size vsize*vsize*vsize + output an image to output_filename + """ + points = volume_to_point_cloud(vol) + pyplot_draw_point_cloud(points, output_filename) diff --git a/utils/plyfile.py b/utils/plyfile.py new file mode 100644 index 0000000..69c2aa9 --- /dev/null +++ b/utils/plyfile.py @@ -0,0 +1,916 @@ +# Copyright 2014 Darsh Ranjan +# +# This file is part of python-plyfile. 
+# +# python-plyfile is free software: you can redistribute it and/or +# modify it under the terms of the GNU General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# python-plyfile is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with python-plyfile. If not, see +# . + +from itertools import islice as _islice + +import numpy as _np +from sys import byteorder as _byteorder + + +try: + _range = xrange +except NameError: + _range = range + + +# Many-many relation +_data_type_relation = [ + ('int8', 'i1'), + ('char', 'i1'), + ('uint8', 'u1'), + ('uchar', 'b1'), + ('uchar', 'u1'), + ('int16', 'i2'), + ('short', 'i2'), + ('uint16', 'u2'), + ('ushort', 'u2'), + ('int32', 'i4'), + ('int', 'i4'), + ('uint32', 'u4'), + ('uint', 'u4'), + ('float32', 'f4'), + ('float', 'f4'), + ('float64', 'f8'), + ('double', 'f8') +] + +_data_types = dict(_data_type_relation) +_data_type_reverse = dict((b, a) for (a, b) in _data_type_relation) + +_types_list = [] +_types_set = set() +for (_a, _b) in _data_type_relation: + if _a not in _types_set: + _types_list.append(_a) + _types_set.add(_a) + if _b not in _types_set: + _types_list.append(_b) + _types_set.add(_b) + + +_byte_order_map = { + 'ascii': '=', + 'binary_little_endian': '<', + 'binary_big_endian': '>' +} + +_byte_order_reverse = { + '<': 'binary_little_endian', + '>': 'binary_big_endian' +} + +_native_byte_order = {'little': '<', 'big': '>'}[_byteorder] + + +def _lookup_type(type_str): + if type_str not in _data_type_reverse: + try: + type_str = _data_types[type_str] + except KeyError: + raise ValueError("field type %r not in %r" % + (type_str, _types_list)) + + return _data_type_reverse[type_str] + + +def _split_line(line, n): + fields = line.split(None, n) + if len(fields) == n: + fields.append('') + + assert len(fields) == n + 1 + + return fields + + +def make2d(array, cols=None, dtype=None): + ''' + Make a 2D array from an array of arrays. The `cols' and `dtype' + arguments can be omitted if the array is not empty. + + ''' + if (cols is None or dtype is None) and not len(array): + raise RuntimeError("cols and dtype must be specified for empty " + "array") + + if cols is None: + cols = len(array[0]) + + if dtype is None: + dtype = array[0].dtype + + return _np.fromiter(array, [('_', dtype, (cols,))], + count=len(array))['_'] + + +class PlyParseError(Exception): + + ''' + Raised when a PLY file cannot be parsed. + + The attributes `element', `row', `property', and `message' give + additional information. + + ''' + + def __init__(self, message, element=None, row=None, prop=None): + self.message = message + self.element = element + self.row = row + self.prop = prop + + s = '' + if self.element: + s += 'element %r: ' % self.element.name + if self.row is not None: + s += 'row %d: ' % self.row + if self.prop: + s += 'property %r: ' % self.prop.name + s += self.message + + Exception.__init__(self, s) + + def __repr__(self): + return ('PlyParseError(%r, element=%r, row=%r, prop=%r)' % + self.message, self.element, self.row, self.prop) + + +class PlyData(object): + + ''' + PLY file header and data. 
+ + A PlyData instance is created in one of two ways: by the static + method PlyData.read (to read a PLY file), or directly from __init__ + given a sequence of elements (which can then be written to a PLY + file). + + ''' + + def __init__(self, elements=[], text=False, byte_order='=', + comments=[], obj_info=[]): + ''' + elements: sequence of PlyElement instances. + + text: whether the resulting PLY file will be text (True) or + binary (False). + + byte_order: '<' for little-endian, '>' for big-endian, or '=' + for native. This is only relevant if `text' is False. + + comments: sequence of strings that will be placed in the header + between the 'ply' and 'format ...' lines. + + obj_info: like comments, but will be placed in the header with + "obj_info ..." instead of "comment ...". + + ''' + if byte_order == '=' and not text: + byte_order = _native_byte_order + + self.byte_order = byte_order + self.text = text + + self.comments = list(comments) + self.obj_info = list(obj_info) + self.elements = elements + + def _get_elements(self): + return self._elements + + def _set_elements(self, elements): + self._elements = tuple(elements) + self._index() + + elements = property(_get_elements, _set_elements) + + def _get_byte_order(self): + return self._byte_order + + def _set_byte_order(self, byte_order): + if byte_order not in ['<', '>', '=']: + raise ValueError("byte order must be '<', '>', or '='") + + self._byte_order = byte_order + + byte_order = property(_get_byte_order, _set_byte_order) + + def _index(self): + self._element_lookup = dict((elt.name, elt) for elt in + self._elements) + if len(self._element_lookup) != len(self._elements): + raise ValueError("two elements with same name") + + @staticmethod + def _parse_header(stream): + ''' + Parse a PLY header from a readable file-like stream. + + ''' + lines = [] + comments = {'comment': [], 'obj_info': []} + while True: + line = stream.readline().decode('ascii').strip() + fields = _split_line(line, 1) + + if fields[0] == 'end_header': + break + + elif fields[0] in comments.keys(): + lines.append(fields) + else: + lines.append(line.split()) + + a = 0 + if lines[a] != ['ply']: + raise PlyParseError("expected 'ply'") + + a += 1 + while lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + if lines[a][0] != 'format': + raise PlyParseError("expected 'format'") + + if lines[a][2] != '1.0': + raise PlyParseError("expected version '1.0'") + + if len(lines[a]) != 3: + raise PlyParseError("too many fields after 'format'") + + fmt = lines[a][1] + + if fmt not in _byte_order_map: + raise PlyParseError("don't understand format %r" % fmt) + + byte_order = _byte_order_map[fmt] + text = fmt == 'ascii' + + a += 1 + while a < len(lines) and lines[a][0] in comments.keys(): + comments[lines[a][0]].append(lines[a][1]) + a += 1 + + return PlyData(PlyElement._parse_multi(lines[a:]), + text, byte_order, + comments['comment'], comments['obj_info']) + + @staticmethod + def read(stream): + ''' + Read PLY data from a readable file-like object or filename. + + ''' + (must_close, stream) = _open_stream(stream, 'read') + try: + data = PlyData._parse_header(stream) + for elt in data: + elt._read(stream, data.text, data.byte_order) + finally: + if must_close: + stream.close() + + return data + + def write(self, stream): + ''' + Write PLY data to a writeable file-like object or filename. 
+ + ''' + (must_close, stream) = _open_stream(stream, 'write') + try: + stream.write(self.header.encode('ascii')) + stream.write(b'\r\n') + for elt in self: + elt._write(stream, self.text, self.byte_order) + finally: + if must_close: + stream.close() + + @property + def header(self): + ''' + Provide PLY-formatted metadata for the instance. + + ''' + lines = ['ply'] + + if self.text: + lines.append('format ascii 1.0') + else: + lines.append('format ' + + _byte_order_reverse[self.byte_order] + + ' 1.0') + + # Some information is lost here, since all comments are placed + # between the 'format' line and the first element. + for c in self.comments: + lines.append('comment ' + c) + + for c in self.obj_info: + lines.append('obj_info ' + c) + + lines.extend(elt.header for elt in self.elements) + lines.append('end_header') + return '\r\n'.join(lines) + + def __iter__(self): + return iter(self.elements) + + def __len__(self): + return len(self.elements) + + def __contains__(self, name): + return name in self._element_lookup + + def __getitem__(self, name): + return self._element_lookup[name] + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyData(%r, text=%r, byte_order=%r, ' + 'comments=%r, obj_info=%r)' % + (self.elements, self.text, self.byte_order, + self.comments, self.obj_info)) + + +def _open_stream(stream, read_or_write): + if hasattr(stream, read_or_write): + return (False, stream) + try: + return (True, open(stream, read_or_write[0] + 'b')) + except TypeError: + raise RuntimeError("expected open file or filename") + + +class PlyElement(object): + + ''' + PLY file element. + + A client of this library doesn't normally need to instantiate this + directly, so the following is only for the sake of documenting the + internals. + + Creating a PlyElement instance is generally done in one of two ways: + as a byproduct of PlyData.read (when reading a PLY file) and by + PlyElement.describe (before writing a PLY file). + + ''' + + def __init__(self, name, properties, count, comments=[]): + ''' + This is not part of the public interface. The preferred methods + of obtaining PlyElement instances are PlyData.read (to read from + a file) and PlyElement.describe (to construct from a numpy + array). 
+ + ''' + self._name = str(name) + self._check_name() + self._count = count + + self._properties = tuple(properties) + self._index() + + self.comments = list(comments) + + self._have_list = any(isinstance(p, PlyListProperty) + for p in self.properties) + + @property + def count(self): + return self._count + + def _get_data(self): + return self._data + + def _set_data(self, data): + self._data = data + self._count = len(data) + self._check_sanity() + + data = property(_get_data, _set_data) + + def _check_sanity(self): + for prop in self.properties: + if prop.name not in self._data.dtype.fields: + raise ValueError("dangling property %r" % prop.name) + + def _get_properties(self): + return self._properties + + def _set_properties(self, properties): + self._properties = tuple(properties) + self._check_sanity() + self._index() + + properties = property(_get_properties, _set_properties) + + def _index(self): + self._property_lookup = dict((prop.name, prop) + for prop in self._properties) + if len(self._property_lookup) != len(self._properties): + raise ValueError("two properties with same name") + + def ply_property(self, name): + return self._property_lookup[name] + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "element name %r contains spaces" % self._name + raise ValueError(msg) + + def dtype(self, byte_order='='): + ''' + Return the numpy dtype of the in-memory representation of the + data. (If there are no list properties, and the PLY format is + binary, then this also accurately describes the on-disk + representation of the element.) + + ''' + return [(prop.name, prop.dtype(byte_order)) + for prop in self.properties] + + @staticmethod + def _parse_multi(header_lines): + ''' + Parse a list of PLY element definitions. + + ''' + elements = [] + while header_lines: + (elt, header_lines) = PlyElement._parse_one(header_lines) + elements.append(elt) + + return elements + + @staticmethod + def _parse_one(lines): + ''' + Consume one element definition. The unconsumed input is + returned along with a PlyElement instance. + + ''' + a = 0 + line = lines[a] + + if line[0] != 'element': + raise PlyParseError("expected 'element'") + if len(line) > 3: + raise PlyParseError("too many fields after 'element'") + if len(line) < 3: + raise PlyParseError("too few fields after 'element'") + + (name, count) = (line[1], int(line[2])) + + comments = [] + properties = [] + while True: + a += 1 + if a >= len(lines): + break + + if lines[a][0] == 'comment': + comments.append(lines[a][1]) + elif lines[a][0] == 'property': + properties.append(PlyProperty._parse_one(lines[a])) + else: + break + + return (PlyElement(name, properties, count, comments), + lines[a:]) + + @staticmethod + def describe(data, name, len_types={}, val_types={}, + comments=[]): + ''' + Construct a PlyElement from an array's metadata. + + len_types and val_types can be given as mappings from list + property names to type strings (like 'u1', 'f4', etc., or + 'int8', 'float32', etc.). These can be used to define the length + and value types of list properties. List property lengths + always default to type 'u1' (8-bit unsigned integer), and value + types default to 'i4' (32-bit integer). 
+ + ''' + if not isinstance(data, _np.ndarray): + raise TypeError("only numpy arrays are supported") + + if len(data.shape) != 1: + raise ValueError("only one-dimensional arrays are " + "supported") + + count = len(data) + + properties = [] + descr = data.dtype.descr + + for t in descr: + if not isinstance(t[1], str): + raise ValueError("nested records not supported") + + if not t[0]: + raise ValueError("field with empty name") + + if len(t) != 2 or t[1][1] == 'O': + # non-scalar field, which corresponds to a list + # property in PLY. + + if t[1][1] == 'O': + if len(t) != 2: + raise ValueError("non-scalar object fields not " + "supported") + + len_str = _data_type_reverse[len_types.get(t[0], 'u1')] + if t[1][1] == 'O': + val_type = val_types.get(t[0], 'i4') + val_str = _lookup_type(val_type) + else: + val_str = _lookup_type(t[1][1:]) + + prop = PlyListProperty(t[0], len_str, val_str) + else: + val_str = _lookup_type(t[1][1:]) + prop = PlyProperty(t[0], val_str) + + properties.append(prop) + + elt = PlyElement(name, properties, count, comments) + elt.data = data + + return elt + + def _read(self, stream, text, byte_order): + ''' + Read the actual data from a PLY file. + + ''' + if text: + self._read_txt(stream) + else: + if self._have_list: + # There are list properties, so a simple load is + # impossible. + self._read_bin(stream, byte_order) + else: + # There are no list properties, so loading the data is + # much more straightforward. + self._data = _np.fromfile(stream, + self.dtype(byte_order), + self.count) + + if len(self._data) < self.count: + k = len(self._data) + del self._data + raise PlyParseError("early end-of-file", self, k) + + self._check_sanity() + + def _write(self, stream, text, byte_order): + ''' + Write the data to a PLY file. + + ''' + if text: + self._write_txt(stream) + else: + if self._have_list: + # There are list properties, so serialization is + # slightly complicated. + self._write_bin(stream, byte_order) + else: + # no list properties, so serialization is + # straightforward. + self.data.astype(self.dtype(byte_order), + copy=False).tofile(stream) + + def _read_txt(self, stream): + ''' + Load a PLY element from an ASCII-format PLY file. The element + may contain list properties. + + ''' + self._data = _np.empty(self.count, dtype=self.dtype()) + + k = 0 + for line in _islice(iter(stream.readline, b''), self.count): + fields = iter(line.strip().split()) + for prop in self.properties: + try: + self._data[prop.name][k] = prop._from_fields(fields) + except StopIteration: + raise PlyParseError("early end-of-line", + self, k, prop) + except ValueError: + raise PlyParseError("malformed input", + self, k, prop) + try: + next(fields) + except StopIteration: + pass + else: + raise PlyParseError("expected end-of-line", self, k) + k += 1 + + if k < self.count: + del self._data + raise PlyParseError("early end-of-file", self, k) + + def _write_txt(self, stream): + ''' + Save a PLY element to an ASCII-format PLY file. The element may + contain list properties. + + ''' + for rec in self.data: + fields = [] + for prop in self.properties: + fields.extend(prop._to_fields(rec[prop.name])) + + _np.savetxt(stream, [fields], '%.18g', newline='\r\n') + + def _read_bin(self, stream, byte_order): + ''' + Load a PLY element from a binary PLY file. The element may + contain list properties. 
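+
+        Because each list can have a different length, the rows cannot
+        be loaded with a single _np.fromfile call; they are parsed one
+        property at a time instead.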
+ + ''' + self._data = _np.empty(self.count, dtype=self.dtype(byte_order)) + + for k in _range(self.count): + for prop in self.properties: + try: + self._data[prop.name][k] = \ + prop._read_bin(stream, byte_order) + except StopIteration: + raise PlyParseError("early end-of-file", + self, k, prop) + + def _write_bin(self, stream, byte_order): + ''' + Save a PLY element to a binary PLY file. The element may + contain list properties. + + ''' + for rec in self.data: + for prop in self.properties: + prop._write_bin(rec[prop.name], stream, byte_order) + + @property + def header(self): + ''' + Format this element's metadata as it would appear in a PLY + header. + + ''' + lines = ['element %s %d' % (self.name, self.count)] + + # Some information is lost here, since all comments are placed + # between the 'element' line and the first property definition. + for c in self.comments: + lines.append('comment ' + c) + + lines.extend(list(map(str, self.properties))) + + return '\r\n'.join(lines) + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def __str__(self): + return self.header + + def __repr__(self): + return ('PlyElement(%r, %r, count=%d, comments=%r)' % + (self.name, self.properties, self.count, + self.comments)) + + +class PlyProperty(object): + + ''' + PLY property description. This class is pure metadata; the data + itself is contained in PlyElement instances. + + ''' + + def __init__(self, name, val_dtype): + self._name = str(name) + self._check_name() + self.val_dtype = val_dtype + + def _get_val_dtype(self): + return self._val_dtype + + def _set_val_dtype(self, val_dtype): + self._val_dtype = _data_types[_lookup_type(val_dtype)] + + val_dtype = property(_get_val_dtype, _set_val_dtype) + + @property + def name(self): + return self._name + + def _check_name(self): + if any(c.isspace() for c in self._name): + msg = "Error: property name %r contains spaces" % self._name + raise RuntimeError(msg) + + @staticmethod + def _parse_one(line): + assert line[0] == 'property' + + if line[1] == 'list': + if len(line) > 5: + raise PlyParseError("too many fields after " + "'property list'") + if len(line) < 5: + raise PlyParseError("too few fields after " + "'property list'") + + return PlyListProperty(line[4], line[2], line[3]) + + else: + if len(line) > 3: + raise PlyParseError("too many fields after " + "'property'") + if len(line) < 3: + raise PlyParseError("too few fields after " + "'property'") + + return PlyProperty(line[2], line[1]) + + def dtype(self, byte_order='='): + ''' + Return the numpy dtype description for this property (as a tuple + of strings). + + ''' + return byte_order + self.val_dtype + + def _from_fields(self, fields): + ''' + Parse from generator. Raise StopIteration if the property could + not be read. + + ''' + return _np.dtype(self.dtype()).type(next(fields)) + + def _to_fields(self, data): + ''' + Return generator over one item. + + ''' + yield _np.dtype(self.dtype()).type(data) + + def _read_bin(self, stream, byte_order): + ''' + Read data from a binary stream. Raise StopIteration if the + property could not be read. + + ''' + try: + return _np.fromfile(stream, self.dtype(byte_order), 1)[0] + except IndexError: + raise StopIteration + + def _write_bin(self, data, stream, byte_order): + ''' + Write data to a binary stream. 
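+
+        The value is cast to this property's dtype (in the requested
+        byte order) before being serialized.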
+ + ''' + _np.dtype(self.dtype(byte_order)).type(data).tofile(stream) + + def __str__(self): + val_str = _data_type_reverse[self.val_dtype] + return 'property %s %s' % (val_str, self.name) + + def __repr__(self): + return 'PlyProperty(%r, %r)' % (self.name, + _lookup_type(self.val_dtype)) + + +class PlyListProperty(PlyProperty): + + ''' + PLY list property description. + + ''' + + def __init__(self, name, len_dtype, val_dtype): + PlyProperty.__init__(self, name, val_dtype) + + self.len_dtype = len_dtype + + def _get_len_dtype(self): + return self._len_dtype + + def _set_len_dtype(self, len_dtype): + self._len_dtype = _data_types[_lookup_type(len_dtype)] + + len_dtype = property(_get_len_dtype, _set_len_dtype) + + def dtype(self, byte_order='='): + ''' + List properties always have a numpy dtype of "object". + + ''' + return '|O' + + def list_dtype(self, byte_order='='): + ''' + Return the pair (len_dtype, val_dtype) (both numpy-friendly + strings). + + ''' + return (byte_order + self.len_dtype, + byte_order + self.val_dtype) + + def _from_fields(self, fields): + (len_t, val_t) = self.list_dtype() + + n = int(_np.dtype(len_t).type(next(fields))) + + data = _np.loadtxt(list(_islice(fields, n)), val_t, ndmin=1) + if len(data) < n: + raise StopIteration + + return data + + def _to_fields(self, data): + ''' + Return generator over the (numerical) PLY representation of the + list data (length followed by actual data). + + ''' + (len_t, val_t) = self.list_dtype() + + data = _np.asarray(data, dtype=val_t).ravel() + + yield _np.dtype(len_t).type(data.size) + for x in data: + yield x + + def _read_bin(self, stream, byte_order): + (len_t, val_t) = self.list_dtype(byte_order) + + try: + n = _np.fromfile(stream, len_t, 1)[0] + except IndexError: + raise StopIteration + + data = _np.fromfile(stream, val_t, n) + if len(data) < n: + raise StopIteration + + return data + + def _write_bin(self, data, stream, byte_order): + ''' + Write data to a binary stream. + + ''' + (len_t, val_t) = self.list_dtype(byte_order) + + data = _np.asarray(data, dtype=val_t).ravel() + + _np.array(data.size, dtype=len_t).tofile(stream) + data.tofile(stream) + + def __str__(self): + len_str = _data_type_reverse[self.len_dtype] + val_str = _data_type_reverse[self.val_dtype] + return 'property list %s %s %s' % (len_str, val_str, self.name) + + def __repr__(self): + return ('PlyListProperty(%r, %r, %r)' % + (self.name, + _lookup_type(self.len_dtype), + _lookup_type(self.val_dtype))) diff --git a/utils/tf_util.py b/utils/tf_util.py new file mode 100644 index 0000000..4f76eac --- /dev/null +++ b/utils/tf_util.py @@ -0,0 +1,639 @@ +""" Wrapper functions for TensorFlow layers. + +Author: Charles R. Qi +Date: November 2016 +""" + +import numpy as np +import tensorflow as tf + +def _variable_on_cpu(name, shape, initializer, use_fp16=False): + """Helper to create a Variable stored on CPU memory. + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + dtype = tf.float16 if use_fp16 else tf.float32 + var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) + return var + +def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True): + """Helper to create an initialized Variable with weight decay. + + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. 
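+
+  The decay term is added to the 'losses' collection, so a total
+  regularized loss can be formed with, e.g. (a sketch, where task_loss
+  is whatever primary loss the caller uses):
+
+      total_loss = task_loss + tf.add_n(tf.get_collection('losses'))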
+ + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + use_xavier: bool, whether to use xavier initializer + + Returns: + Variable Tensor + """ + if use_xavier: + initializer = tf.contrib.layers.xavier_initializer() + else: + initializer = tf.truncated_normal_initializer(stddev=stddev) + var = _variable_on_cpu(name, shape, initializer) + if wd is not None: + weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') + tf.add_to_collection('losses', weight_decay) + return var + + +def conv1d(inputs, + num_output_channels, + kernel_size, + scope, + stride=1, + padding='SAME', + use_xavier=True, + stddev=1e-3, + weight_decay=0.0, + activation_fn=tf.nn.relu, + bn=False, + bn_decay=None, + is_training=None): + """ 1D convolution with non-linear operation. + + Args: + inputs: 3-D tensor variable BxLxC + num_output_channels: int + kernel_size: int + scope: string + stride: int + padding: 'SAME' or 'VALID' + use_xavier: bool, use xavier_initializer if true + stddev: float, stddev for truncated_normal init + weight_decay: float + activation_fn: function + bn: bool, whether to use batch norm + bn_decay: float or float tensor variable in [0,1] + is_training: bool Tensor variable + + Returns: + Variable tensor + """ + with tf.variable_scope(scope) as sc: + num_in_channels = inputs.get_shape()[-1].value + kernel_shape = [kernel_size, + num_in_channels, num_output_channels] + kernel = _variable_with_weight_decay('weights', + shape=kernel_shape, + use_xavier=use_xavier, + stddev=stddev, + wd=weight_decay) + outputs = tf.nn.conv1d(inputs, kernel, + stride=stride, + padding=padding) + biases = _variable_on_cpu('biases', [num_output_channels], + tf.constant_initializer(0.0)) + outputs = tf.nn.bias_add(outputs, biases) + + if bn: + outputs = batch_norm_for_conv1d(outputs, is_training, + bn_decay=bn_decay, scope='bn') + + if activation_fn is not None: + outputs = activation_fn(outputs) + return outputs + + + + +def conv2d(inputs, + num_output_channels, + kernel_size, + scope, + stride=[1, 1], + padding='SAME', + use_xavier=True, + stddev=1e-3, + weight_decay=0.0, + activation_fn=tf.nn.relu, + bn=False, + bn_decay=None, + is_training=None): + """ 2D convolution with non-linear operation. 
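+
+  For example (shapes are illustrative; `images` is any BxHxWxC
+  tensor), a 3x3 convolution producing 64 output channels:
+
+      net = conv2d(images, 64, [3,3], scope='conv1')  # BxHxWxC -> BxHxWx64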
+
+  Args:
+    inputs: 4-D tensor variable BxHxWxC
+    num_output_channels: int
+    kernel_size: a list of 2 ints
+    scope: string
+    stride: a list of 2 ints
+    padding: 'SAME' or 'VALID'
+    use_xavier: bool, use xavier_initializer if true
+    stddev: float, stddev for truncated_normal init
+    weight_decay: float
+    activation_fn: function
+    bn: bool, whether to use batch norm
+    bn_decay: float or float tensor variable in [0,1]
+    is_training: bool Tensor variable
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_h, kernel_w = kernel_size
+    num_in_channels = inputs.get_shape()[-1].value
+    kernel_shape = [kernel_h, kernel_w,
+                    num_in_channels, num_output_channels]
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=kernel_shape,
+                                         use_xavier=use_xavier,
+                                         stddev=stddev,
+                                         wd=weight_decay)
+    stride_h, stride_w = stride
+    outputs = tf.nn.conv2d(inputs, kernel,
+                           [1, stride_h, stride_w, 1],
+                           padding=padding)
+    biases = _variable_on_cpu('biases', [num_output_channels],
+                              tf.constant_initializer(0.0))
+    outputs = tf.nn.bias_add(outputs, biases)
+
+    if bn:
+      outputs = batch_norm_for_conv2d(outputs, is_training,
+                                      bn_decay=bn_decay, scope='bn')
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return outputs
+
+
+def conv2d_transpose(inputs,
+                     num_output_channels,
+                     kernel_size,
+                     scope,
+                     stride=[1, 1],
+                     padding='SAME',
+                     use_xavier=True,
+                     stddev=1e-3,
+                     weight_decay=0.0,
+                     activation_fn=tf.nn.relu,
+                     bn=False,
+                     bn_decay=None,
+                     is_training=None):
+  """ 2D convolution transpose with non-linear operation.
+
+  Args:
+    inputs: 4-D tensor variable BxHxWxC
+    num_output_channels: int
+    kernel_size: a list of 2 ints
+    scope: string
+    stride: a list of 2 ints
+    padding: 'SAME' or 'VALID'
+    use_xavier: bool, use xavier_initializer if true
+    stddev: float, stddev for truncated_normal init
+    weight_decay: float
+    activation_fn: function
+    bn: bool, whether to use batch norm
+    bn_decay: float or float tensor variable in [0,1]
+    is_training: bool Tensor variable
+
+  Returns:
+    Variable tensor
+
+  Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_h, kernel_w = kernel_size
+    num_in_channels = inputs.get_shape()[-1].value
+    kernel_shape = [kernel_h, kernel_w,
+                    num_output_channels, num_in_channels]  # reversed compared to conv2d
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=kernel_shape,
+                                         use_xavier=use_xavier,
+                                         stddev=stddev,
+                                         wd=weight_decay)
+    stride_h, stride_w = stride
+
+    # from slim.convolution2d_transpose
+    def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
+      dim_size *= stride_size
+
+      if padding == 'VALID' and dim_size is not None:
+        dim_size += max(kernel_size - stride_size, 0)
+      return dim_size
+
+    # calculate output shape
+    batch_size = inputs.get_shape()[0].value
+    height = inputs.get_shape()[1].value
+    width = inputs.get_shape()[2].value
+    out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
+    out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
+    output_shape = [batch_size, out_height, out_width, num_output_channels]
+
+    outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape,
+                                     [1, stride_h, stride_w, 1],
+                                     padding=padding)
+    biases = _variable_on_cpu('biases', [num_output_channels],
+                              tf.constant_initializer(0.0))
+    outputs = tf.nn.bias_add(outputs, biases)
+
+    if bn:
+      outputs = batch_norm_for_conv2d(outputs, is_training,
+                                      bn_decay=bn_decay, scope='bn')
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return outputs
+
+
+def conv3d(inputs,
+           num_output_channels,
+           kernel_size,
+           scope,
+           stride=[1, 1, 1],
+           padding='SAME',
+           use_xavier=True,
+           stddev=1e-3,
+           weight_decay=0.0,
+           activation_fn=tf.nn.relu,
+           bn=False,
+           bn_decay=None,
+           is_training=None):
+  """ 3D convolution with non-linear operation.
+
+  Args:
+    inputs: 5-D tensor variable BxDxHxWxC
+    num_output_channels: int
+    kernel_size: a list of 3 ints
+    scope: string
+    stride: a list of 3 ints
+    padding: 'SAME' or 'VALID'
+    use_xavier: bool, use xavier_initializer if true
+    stddev: float, stddev for truncated_normal init
+    weight_decay: float
+    activation_fn: function
+    bn: bool, whether to use batch norm
+    bn_decay: float or float tensor variable in [0,1]
+    is_training: bool Tensor variable
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_d, kernel_h, kernel_w = kernel_size
+    num_in_channels = inputs.get_shape()[-1].value
+    kernel_shape = [kernel_d, kernel_h, kernel_w,
+                    num_in_channels, num_output_channels]
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=kernel_shape,
+                                         use_xavier=use_xavier,
+                                         stddev=stddev,
+                                         wd=weight_decay)
+    stride_d, stride_h, stride_w = stride
+    outputs = tf.nn.conv3d(inputs, kernel,
+                           [1, stride_d, stride_h, stride_w, 1],
+                           padding=padding)
+    biases = _variable_on_cpu('biases', [num_output_channels],
+                              tf.constant_initializer(0.0))
+    outputs = tf.nn.bias_add(outputs, biases)
+
+    if bn:
+      outputs = batch_norm_for_conv3d(outputs, is_training,
+                                      bn_decay=bn_decay, scope='bn')
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return outputs
+
+def fully_connected(inputs,
+                    num_outputs,
+                    scope,
+                    use_xavier=True,
+                    stddev=1e-3,
+                    weight_decay=0.0,
+                    activation_fn=tf.nn.relu,
+                    bn=False,
+                    bn_decay=None,
+                    is_training=None):
+  """ Fully connected layer with non-linear operation.
+
+  Args:
+    inputs: 2-D tensor BxN
+    num_outputs: int
+
+  Returns:
+    Variable tensor of size B x num_outputs.
+  """
+  with tf.variable_scope(scope) as sc:
+    num_input_units = inputs.get_shape()[-1].value
+    weights = _variable_with_weight_decay('weights',
+                                          shape=[num_input_units, num_outputs],
+                                          use_xavier=use_xavier,
+                                          stddev=stddev,
+                                          wd=weight_decay)
+    outputs = tf.matmul(inputs, weights)
+    biases = _variable_on_cpu('biases', [num_outputs],
+                              tf.constant_initializer(0.0))
+    outputs = tf.nn.bias_add(outputs, biases)
+
+    if bn:
+      outputs = batch_norm_for_fc(outputs, is_training, bn_decay, 'bn')
+
+    if activation_fn is not None:
+      outputs = activation_fn(outputs)
+    return outputs
+
+
+def max_pool2d(inputs,
+               kernel_size,
+               scope,
+               stride=[2, 2],
+               padding='VALID'):
+  """ 2D max pooling.
+
+  Args:
+    inputs: 4-D tensor BxHxWxC
+    kernel_size: a list of 2 ints
+    stride: a list of 2 ints
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_h, kernel_w = kernel_size
+    stride_h, stride_w = stride
+    outputs = tf.nn.max_pool(inputs,
+                             ksize=[1, kernel_h, kernel_w, 1],
+                             strides=[1, stride_h, stride_w, 1],
+                             padding=padding,
+                             name=sc.name)
+    return outputs
+
+def avg_pool2d(inputs,
+               kernel_size,
+               scope,
+               stride=[2, 2],
+               padding='VALID'):
+  """ 2D avg pooling.
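+
+  For example, avg_pool2d(net, [2,2], scope='pool1') halves the spatial
+  resolution with the default stride of [2,2].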
+
+  Args:
+    inputs: 4-D tensor BxHxWxC
+    kernel_size: a list of 2 ints
+    stride: a list of 2 ints
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_h, kernel_w = kernel_size
+    stride_h, stride_w = stride
+    outputs = tf.nn.avg_pool(inputs,
+                             ksize=[1, kernel_h, kernel_w, 1],
+                             strides=[1, stride_h, stride_w, 1],
+                             padding=padding,
+                             name=sc.name)
+    return outputs
+
+
+def max_pool3d(inputs,
+               kernel_size,
+               scope,
+               stride=[2, 2, 2],
+               padding='VALID'):
+  """ 3D max pooling.
+
+  Args:
+    inputs: 5-D tensor BxDxHxWxC
+    kernel_size: a list of 3 ints
+    stride: a list of 3 ints
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_d, kernel_h, kernel_w = kernel_size
+    stride_d, stride_h, stride_w = stride
+    outputs = tf.nn.max_pool3d(inputs,
+                               ksize=[1, kernel_d, kernel_h, kernel_w, 1],
+                               strides=[1, stride_d, stride_h, stride_w, 1],
+                               padding=padding,
+                               name=sc.name)
+    return outputs
+
+def avg_pool3d(inputs,
+               kernel_size,
+               scope,
+               stride=[2, 2, 2],
+               padding='VALID'):
+  """ 3D avg pooling.
+
+  Args:
+    inputs: 5-D tensor BxDxHxWxC
+    kernel_size: a list of 3 ints
+    stride: a list of 3 ints
+
+  Returns:
+    Variable tensor
+  """
+  with tf.variable_scope(scope) as sc:
+    kernel_d, kernel_h, kernel_w = kernel_size
+    stride_d, stride_h, stride_w = stride
+    outputs = tf.nn.avg_pool3d(inputs,
+                               ksize=[1, kernel_d, kernel_h, kernel_w, 1],
+                               strides=[1, stride_d, stride_h, stride_w, 1],
+                               padding=padding,
+                               name=sc.name)
+    return outputs
+
+
+
+
+def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay):
+  """ Batch normalization on convolutional maps and beyond...
+  Ref.: http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow
+
+  Args:
+    inputs: Tensor, k-D input ... x C could be BC or BHWC or BDHWC
+    is_training: boolean tf.Variable, true indicates training phase
+    scope: string, variable scope
+    moments_dims: a list of ints, indicating dimensions for moments calculation
+    bn_decay: float or float tensor variable, controlling moving average weight
+  Return:
+    normed: batch-normalized maps
+  """
+  with tf.variable_scope(scope) as sc:
+    num_channels = inputs.get_shape()[-1].value
+    beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
+                       name='beta', trainable=True)
+    gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
+                        name='gamma', trainable=True)
+    batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
+    decay = bn_decay if bn_decay is not None else 0.9
+    ema = tf.train.ExponentialMovingAverage(decay=decay)
+    # Operator that maintains moving averages of variables.
+    ema_apply_op = tf.cond(is_training,
+                           lambda: ema.apply([batch_mean, batch_var]),
+                           lambda: tf.no_op())
+
+    # Update moving average and return current batch's avg and var.
+    def mean_var_with_update():
+      with tf.control_dependencies([ema_apply_op]):
+        return tf.identity(batch_mean), tf.identity(batch_var)
+
+    # ema.average returns the Variable holding the average of var.
+    mean, var = tf.cond(is_training,
+                        mean_var_with_update,
+                        lambda: (ema.average(batch_mean), ema.average(batch_var)))
+    normed = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, 1e-3)
+  return normed
+
+
+def batch_norm_for_fc(inputs, is_training, bn_decay, scope):
+  """ Batch normalization on FC data.
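+
+  Moments are computed over the batch dimension only (axis 0), so each
+  of the C features is normalized independently.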
+
+  Args:
+    inputs: Tensor, 2D BxC input
+    is_training: boolean tf.Variable, true indicates training phase
+    bn_decay: float or float tensor variable, controlling moving average weight
+    scope: string, variable scope
+  Return:
+    normed: batch-normalized maps
+  """
+  return batch_norm_template(inputs, is_training, scope, [0,], bn_decay)
+
+
+def batch_norm_for_conv1d(inputs, is_training, bn_decay, scope):
+  """ Batch normalization on 1D convolutional maps.
+
+  Args:
+    inputs: Tensor, 3D BLC input maps
+    is_training: boolean tf.Variable, true indicates training phase
+    bn_decay: float or float tensor variable, controlling moving average weight
+    scope: string, variable scope
+  Return:
+    normed: batch-normalized maps
+  """
+  return batch_norm_template(inputs, is_training, scope, [0,1], bn_decay)
+
+
+
+
+def batch_norm_for_conv2d(inputs, is_training, bn_decay, scope):
+  """ Batch normalization on 2D convolutional maps.
+
+  Args:
+    inputs: Tensor, 4D BHWC input maps
+    is_training: boolean tf.Variable, true indicates training phase
+    bn_decay: float or float tensor variable, controlling moving average weight
+    scope: string, variable scope
+  Return:
+    normed: batch-normalized maps
+  """
+  return batch_norm_template(inputs, is_training, scope, [0,1,2], bn_decay)
+
+
+
+def batch_norm_for_conv3d(inputs, is_training, bn_decay, scope):
+  """ Batch normalization on 3D convolutional maps.
+
+  Args:
+    inputs: Tensor, 5D BDHWC input maps
+    is_training: boolean tf.Variable, true indicates training phase
+    bn_decay: float or float tensor variable, controlling moving average weight
+    scope: string, variable scope
+  Return:
+    normed: batch-normalized maps
+  """
+  return batch_norm_template(inputs, is_training, scope, [0,1,2,3], bn_decay)
+
+
+def dropout(inputs,
+            is_training,
+            scope,
+            keep_prob=0.5,
+            noise_shape=None):
+  """ Dropout layer.
+
+  Args:
+    inputs: tensor
+    is_training: boolean tf.Variable
+    scope: string
+    keep_prob: float in [0,1]
+    noise_shape: list of ints
+
+  Returns:
+    tensor variable
+  """
+  with tf.variable_scope(scope) as sc:
+    outputs = tf.cond(is_training,
+                      lambda: tf.nn.dropout(inputs, keep_prob, noise_shape),
+                      lambda: inputs)
+    return outputs
+
+
+def pairwise_distance(point_cloud):
+  """Compute pairwise distance of a point cloud.
+
+  Args:
+    point_cloud: tensor (batch_size, num_points, num_dims)
+
+  Returns:
+    pairwise distance: (batch_size, num_points, num_points)
+  """
+  # tf.squeeze removes every size-1 axis, which would also drop the batch
+  # axis when batch_size == 1; restore it in that case.
+  og_batch_size = point_cloud.get_shape().as_list()[0]
+  point_cloud = tf.squeeze(point_cloud)
+  if og_batch_size == 1:
+    point_cloud = tf.expand_dims(point_cloud, 0)
+  # ||x_i - x_j||^2 = ||x_i||^2 - 2 x_i . x_j + ||x_j||^2
+  point_cloud_transpose = tf.transpose(point_cloud, perm=[0, 2, 1])
+  point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose)
+  point_cloud_inner = -2*point_cloud_inner
+  point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True)
+  point_cloud_square_transpose = tf.transpose(point_cloud_square, perm=[0, 2, 1])
+  return point_cloud_square + point_cloud_inner + point_cloud_square_transpose
+
+
+def knn(adj_matrix, k=20):
+  """Get KNN based on the pairwise distance.
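+
+  The distances are negated so that tf.nn.top_k, which selects the
+  largest entries, returns the indices of the k nearest neighbors.
+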
+  Args:
+    adj_matrix: pairwise distance matrix (batch_size, num_points, num_points)
+    k: int
+
+  Returns:
+    nearest neighbors: (batch_size, num_points, k)
+  """
+  neg_adj = -adj_matrix
+  _, nn_idx = tf.nn.top_k(neg_adj, k=k)
+  return nn_idx
+
+
+def get_edge_feature(point_cloud, nn_idx, k=20):
+  """Construct edge feature for each point.
+  Args:
+    point_cloud: (batch_size, num_points, 1, num_dims)
+      or (batch_size, num_points, num_dims)
+    nn_idx: (batch_size, num_points, k)
+    k: int
+
+  Returns:
+    edge features: (batch_size, num_points, k, 2*num_dims)
+  """
+  # tf.squeeze removes every size-1 axis, which would also drop the batch
+  # axis when batch_size == 1; restore it in that case.
+  og_batch_size = point_cloud.get_shape().as_list()[0]
+  point_cloud = tf.squeeze(point_cloud)
+  if og_batch_size == 1:
+    point_cloud = tf.expand_dims(point_cloud, 0)
+  point_cloud_central = point_cloud
+
+  point_cloud_shape = point_cloud.get_shape()
+  batch_size = point_cloud_shape[0].value
+  num_points = point_cloud_shape[1].value
+  num_dims = point_cloud_shape[2].value
+
+  # Convert per-batch neighbor indices into indices into the flattened
+  # (batch_size*num_points, num_dims) point list.
+  idx_ = tf.range(batch_size) * num_points
+  idx_ = tf.reshape(idx_, [batch_size, 1, 1])
+
+  point_cloud_flat = tf.reshape(point_cloud, [-1, num_dims])
+  point_cloud_neighbors = tf.gather(point_cloud_flat, nn_idx+idx_)
+  point_cloud_central = tf.expand_dims(point_cloud_central, axis=-2)
+
+  point_cloud_central = tf.tile(point_cloud_central, [1, 1, k, 1])
+
+  # Each edge feature is (x_i, x_j - x_i): the central point concatenated
+  # with the offset to its neighbor.
+  edge_feature = tf.concat([point_cloud_central, point_cloud_neighbors-point_cloud_central], axis=-1)
+  return edge_feature
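+
+
+if __name__ == '__main__':
+  # Minimal smoke test for the EdgeConv helpers above (a sketch; the
+  # batch size, point count, and k are arbitrary choices).
+  pc = tf.constant(np.random.rand(4, 128, 3), dtype=tf.float32)
+  adj = pairwise_distance(pc)                        # (4, 128, 128)
+  nn_idx = knn(adj, k=10)                            # (4, 128, 10)
+  edge = get_edge_feature(pc, nn_idx=nn_idx, k=10)   # (4, 128, 10, 6)
+  with tf.Session() as sess:
+    print(sess.run(edge).shape)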