Skip to content

Commit

Permalink
add edge conv
Browse files Browse the repository at this point in the history
  • Loading branch information
WangYueFt committed Feb 3, 2018
1 parent 9c11fa8 commit 2fee5f8
Show file tree
Hide file tree
Showing 9 changed files with 2,935 additions and 0 deletions.
150 changes: 150 additions & 0 deletions models/dgcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import tensorflow as tf
import numpy as np
import math
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, '../utils'))
sys.path.append(os.path.join(BASE_DIR, '../../utils'))
import tf_util
from transform_nets import input_transform_net


def placeholder_inputs(batch_size, num_point):
pointclouds_pl = tf.placeholder(tf.float32, shape=(batch_size, num_point, 3))
labels_pl = tf.placeholder(tf.int32, shape=(batch_size))
return pointclouds_pl, labels_pl


def get_model(point_cloud, is_training, bn_decay=None):
""" Classification PointNet, input is BxNx3, output Bx40 """
batch_size = point_cloud.get_shape()[0].value
num_point = point_cloud.get_shape()[1].value
end_points = {}
k = 20

adj_matrix = tf_util.pairwise_distance(point_cloud)
nn_idx = tf_util.knn(adj_matrix, k=k)
edge_feature = tf_util.get_edge_feature(point_cloud, nn_idx=nn_idx, k=k)

with tf.variable_scope('transform_net1') as sc:
transform = input_transform_net(edge_feature, is_training, bn_decay, K=3)

point_cloud_transformed = tf.matmul(point_cloud, transform)
adj_matrix = tf_util.pairwise_distance(point_cloud_transformed)
nn_idx = tf_util.knn(adj_matrix, k=k)
edge_feature = tf_util.get_edge_feature(point_cloud_transformed, nn_idx=nn_idx, k=k)

net = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='dgcnn1', bn_decay=bn_decay)
net = tf.reduce_max(net, axis=-2, keep_dims=True)
net1 = net

adj_matrix = tf_util.pairwise_distance(net)
nn_idx = tf_util.knn(adj_matrix, k=k)
edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)

net = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='dgcnn2', bn_decay=bn_decay)
net = tf.reduce_max(net, axis=-2, keep_dims=True)
net2 = net

adj_matrix = tf_util.pairwise_distance(net)
nn_idx = tf_util.knn(adj_matrix, k=k)
edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)

net = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='dgcnn3', bn_decay=bn_decay)
net = tf.reduce_max(net, axis=-2, keep_dims=True)
net3 = net

adj_matrix = tf_util.pairwise_distance(net)
nn_idx = tf_util.knn(adj_matrix, k=k)
edge_feature = tf_util.get_edge_feature(net, nn_idx=nn_idx, k=k)

net = tf_util.conv2d(edge_feature, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='dgcnn4', bn_decay=bn_decay)
net = tf.reduce_max(net, axis=-2, keep_dims=True)
net4 = net

net = tf_util.conv2d(tf.concat([net1, net2, net3, net4], axis=-1), 1024, [1, 1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='agg', bn_decay=bn_decay)

net = tf.reduce_max(net, axis=1, keep_dims=True)

# MLP on global point cloud vector
net = tf.reshape(net, [batch_size, -1])
net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
scope='fc1', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
scope='dp1')
net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
scope='fc2', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training,
scope='dp2')
net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')

return net, end_points


def get_loss(pred, label, end_points):
""" pred: B*NUM_CLASSES,
label: B, """
labels = tf.one_hot(indices=label, depth=40)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=pred, label_smoothing=0.2)
classify_loss = tf.reduce_mean(loss)
return classify_loss


if __name__=='__main__':
batch_size = 2
num_pt = 124
pos_dim = 3

input_feed = np.random.rand(batch_size, num_pt, pos_dim)
label_feed = np.random.rand(batch_size)
label_feed[label_feed>=0.5] = 1
label_feed[label_feed<0.5] = 0
label_feed = label_feed.astype(np.int32)

# # np.save('./debug/input_feed.npy', input_feed)
# input_feed = np.load('./debug/input_feed.npy')
# print input_feed

with tf.Graph().as_default():
input_pl, label_pl = placeholder_inputs(batch_size, num_pt)
pos, ftr = get_model(input_pl, tf.constant(True))
# loss = get_loss(logits, label_pl, None)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
feed_dict = {input_pl: input_feed, label_pl: label_feed}
res1, res2 = sess.run([pos, ftr], feed_dict=feed_dict)
print res1.shape
print res1

print res2.shape
print res2












55 changes: 55 additions & 0 deletions models/transform_nets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import tensorflow as tf
import numpy as np
import sys
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, '../utils'))
import tf_util

def input_transform_net(edge_feature, is_training, bn_decay=None, K=3):
""" Input (XYZ) Transform Net, input is BxNx3 gray image
Return:
Transformation matrix of size 3xK """
batch_size = edge_feature.get_shape()[0].value
num_point = edge_feature.get_shape()[1].value

# input_image = tf.expand_dims(point_cloud, -1)
net = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='tconv1', bn_decay=bn_decay)
net = tf_util.conv2d(net, 128, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='tconv2', bn_decay=bn_decay)

net = tf.reduce_max(net, axis=-2, keep_dims=True)

net = tf_util.conv2d(net, 1024, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='tconv3', bn_decay=bn_decay)
net = tf_util.max_pool2d(net, [num_point,1],
padding='VALID', scope='tmaxpool')

net = tf.reshape(net, [batch_size, -1])
net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
scope='tfc1', bn_decay=bn_decay)
net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
scope='tfc2', bn_decay=bn_decay)

with tf.variable_scope('transform_XYZ') as sc:
# assert(K==3)
weights = tf.get_variable('weights', [256, K*K],
initializer=tf.constant_initializer(0.0),
dtype=tf.float32)
biases = tf.get_variable('biases', [K*K],
initializer=tf.constant_initializer(0.0),
dtype=tf.float32)
biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
transform = tf.matmul(net, weights)
transform = tf.nn.bias_add(transform, biases)

transform = tf.reshape(transform, [batch_size, K, K])
return transform
149 changes: 149 additions & 0 deletions provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import os
import sys
import numpy as np
import h5py
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)

# Download dataset for point cloud classification
DATA_DIR = os.path.join(BASE_DIR, 'data')
if not os.path.exists(DATA_DIR):
os.mkdir(DATA_DIR)
if not os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
zipfile = os.path.basename(www)
os.system('wget %s; unzip %s' % (www, zipfile))
os.system('mv %s %s' % (zipfile[:-4], DATA_DIR))
os.system('rm %s' % (zipfile))


def shuffle_data(data, labels):
""" Shuffle data and labels.
Input:
data: B,N,... numpy array
label: B,... numpy array
Return:
shuffled data, label and shuffle indices
"""
idx = np.arange(len(labels))
np.random.shuffle(idx)
return data[idx, ...], labels[idx], idx


def rotate_point_cloud(batch_data):
""" Randomly rotate the point clouds to augument the dataset
rotation is per shape based along up direction
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in xrange(batch_data.shape[0]):
rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data


def rotate_point_cloud_by_angle(batch_data, rotation_angle):
""" Rotate the point cloud along up direction with certain angle.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in xrange(batch_data.shape[0]):
#rotation_angle = np.random.uniform() * 2 * np.pi
cosval = np.cos(rotation_angle)
sinval = np.sin(rotation_angle)
rotation_matrix = np.array([[cosval, 0, sinval],
[0, 1, 0],
[-sinval, 0, cosval]])
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix)
return rotated_data


def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18):
""" Randomly perturb the point clouds by small rotations
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, rotated batch of point clouds
"""
rotated_data = np.zeros(batch_data.shape, dtype=np.float32)
for k in xrange(batch_data.shape[0]):
angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip)
Rx = np.array([[1,0,0],
[0,np.cos(angles[0]),-np.sin(angles[0])],
[0,np.sin(angles[0]),np.cos(angles[0])]])
Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])],
[0,1,0],
[-np.sin(angles[1]),0,np.cos(angles[1])]])
Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0],
[np.sin(angles[2]),np.cos(angles[2]),0],
[0,0,1]])
R = np.dot(Rz, np.dot(Ry,Rx))
shape_pc = batch_data[k, ...]
rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R)
return rotated_data


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
""" Randomly jitter points. jittering is per point.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, jittered batch of point clouds
"""
B, N, C = batch_data.shape
assert(clip > 0)
jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip)
jittered_data += batch_data
return jittered_data

def shift_point_cloud(batch_data, shift_range=0.1):
""" Randomly shift point cloud. Shift is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, shifted batch of point clouds
"""
B, N, C = batch_data.shape
shifts = np.random.uniform(-shift_range, shift_range, (B,3))
for batch_index in range(B):
batch_data[batch_index,:,:] += shifts[batch_index,:]
return batch_data


def random_scale_point_cloud(batch_data, scale_low=0.8, scale_high=1.25):
""" Randomly scale the point cloud. Scale is per point cloud.
Input:
BxNx3 array, original batch of point clouds
Return:
BxNx3 array, scaled batch of point clouds
"""
B, N, C = batch_data.shape
scales = np.random.uniform(scale_low, scale_high, B)
for batch_index in range(B):
batch_data[batch_index,:,:] *= scales[batch_index]
return batch_data

def getDataFiles(list_filename):
return [line.rstrip() for line in open(list_filename)]

def load_h5(h5_filename):
f = h5py.File(h5_filename)
data = f['data'][:]
label = f['label'][:]
return (data, label)

def loadDataFile(filename):
return load_h5(filename)
Loading

0 comments on commit 2fee5f8

Please sign in to comment.