
Commit 666d2b2: Add part segmentation
syb7573330 committed Feb 4, 2018
1 parent 6350a61
Showing 8 changed files with 3,936 additions and 175 deletions.
85 changes: 43 additions & 42 deletions models/transform_nets.py

@@ -7,49 +7,50 @@
 sys.path.append(os.path.join(BASE_DIR, '../utils'))
 import tf_util
 
-def input_transform_net(edge_feature, is_training, bn_decay=None, K=3):
+def input_transform_net(edge_feature, is_training, bn_decay=None, K=3, is_dist=False):
   """ Input (XYZ) Transform Net, input is BxNx3 gray image
     Return:
       Transformation matrix of size 3xK """
   batch_size = edge_feature.get_shape()[0].value
   num_point = edge_feature.get_shape()[1].value
 
   # input_image = tf.expand_dims(point_cloud, -1)
   net = tf_util.conv2d(edge_feature, 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training,
-                       scope='tconv1', bn_decay=bn_decay)
+                       scope='tconv1', bn_decay=bn_decay, is_dist=is_dist)
   net = tf_util.conv2d(net, 128, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training,
-                       scope='tconv2', bn_decay=bn_decay)
+                       scope='tconv2', bn_decay=bn_decay, is_dist=is_dist)
+
   net = tf.reduce_max(net, axis=-2, keep_dims=True)
+
   net = tf_util.conv2d(net, 1024, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training,
-                       scope='tconv3', bn_decay=bn_decay)
+                       scope='tconv3', bn_decay=bn_decay, is_dist=is_dist)
   net = tf_util.max_pool2d(net, [num_point,1],
                            padding='VALID', scope='tmaxpool')
 
   net = tf.reshape(net, [batch_size, -1])
   net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training,
-                                scope='tfc1', bn_decay=bn_decay)
+                                scope='tfc1', bn_decay=bn_decay, is_dist=is_dist)
   net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training,
-                                scope='tfc2', bn_decay=bn_decay)
+                                scope='tfc2', bn_decay=bn_decay, is_dist=is_dist)
 
   with tf.variable_scope('transform_XYZ') as sc:
     # assert(K==3)
-    weights = tf.get_variable('weights', [256, K*K],
-                              initializer=tf.constant_initializer(0.0),
-                              dtype=tf.float32)
-    biases = tf.get_variable('biases', [K*K],
-                             initializer=tf.constant_initializer(0.0),
-                             dtype=tf.float32)
+    with tf.device('/cpu:0'):
+      weights = tf.get_variable('weights', [256, K*K],
+                                initializer=tf.constant_initializer(0.0),
+                                dtype=tf.float32)
+      biases = tf.get_variable('biases', [K*K],
+                               initializer=tf.constant_initializer(0.0),
+                               dtype=tf.float32)
     biases += tf.constant(np.eye(K).flatten(), dtype=tf.float32)
     transform = tf.matmul(net, weights)
     transform = tf.nn.bias_add(transform, biases)
 
   transform = tf.reshape(transform, [batch_size, K, K])
   return transform
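The functional change here is the new is_dist flag, which input_transform_net now forwards to every tf_util.conv2d and tf_util.fully_connected call (tf_util is not part of this commit; in the DGCNN codebase the flag selects a batch-norm variant intended for multi-GPU training), plus the tf.device('/cpu:0') block that pins the transform's weights and biases to the CPU so they can be shared across GPU towers. As a rough usage sketch, mirroring the call site in part_seg_model.py below (shapes are illustrative):

# Sketch: predict a 3x3 alignment transform and apply it (TF 1.x graph style).
import tensorflow as tf
import tf_util
from transform_nets import input_transform_net

point_cloud = tf.placeholder(tf.float32, shape=(32, 2048, 3))  # BxNx3 input points
is_training = tf.placeholder(tf.bool, shape=())
k = 30  # neighborhood size, as in part_seg_model.py

adj = tf_util.pairwise_distance(point_cloud)  # BxNxN pairwise squared distances
nn_idx = tf_util.knn(adj, k=k)                # indices of each point's k neighbors
edge_feature = tf_util.get_edge_feature(tf.expand_dims(point_cloud, -1), nn_idx=nn_idx, k=k)

with tf.variable_scope('transform_net1'):
  transform = input_transform_net(edge_feature, is_training, K=3, is_dist=True)
point_cloud_aligned = tf.matmul(point_cloud, transform)  # canonically oriented BxNx3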
11 changes: 11 additions & 0 deletions part_seg/download_data.sh
@@ -0,0 +1,11 @@
#!/bin/bash

# Download original ShapeNetPart dataset (around 1GB) ['PartAnnotation']
wget https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_v0.zip
unzip shapenetcore_partanno_v0.zip
rm shapenetcore_partanno_v0.zip

# Download HDF5 for ShapeNet Part segmentation (around 346MB) ['hdf5_data']
wget https://shapenet.cs.stanford.edu/media/shapenet_part_seg_hdf5_data.zip
unzip shapenet_part_seg_hdf5_data.zip
rm shapenet_part_seg_hdf5_data.zip
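Once the script finishes, the HDF5 shards can be sanity-checked with a few lines of Python. This is a sketch: the file name is illustrative, and the 'data'/'label'/'pid' keys are the customary layout of the ShapeNet part-segmentation HDF5 release rather than anything defined in this commit.

import h5py

# Open one downloaded shard and report what it contains.
with h5py.File('hdf5_data/ply_data_train0.h5', 'r') as f:
  print(list(f.keys()))  # expected: ['data', 'label', 'pid']
  data = f['data'][:]    # (num_shapes, num_points, 3) xyz coordinates
  label = f['label'][:]  # (num_shapes, 1) object-category id
  pid = f['pid'][:]      # (num_shapes, num_points) per-point part id
  print(data.shape, label.shape, pid.shape)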
141 changes: 141 additions & 0 deletions part_seg/part_seg_model.py
@@ -0,0 +1,141 @@
import tensorflow as tf
import numpy as np
import math
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(BASE_DIR))
sys.path.append(os.path.join(BASE_DIR, '../utils'))
sys.path.append(os.path.join(BASE_DIR, '../models'))
sys.path.append(os.path.join(BASE_DIR, '../'))
import tf_util
from transform_nets import input_transform_net

def get_model(point_cloud, input_label, is_training, cat_num, part_num, \
    batch_size, num_point, weight_decay, bn_decay=None):

  batch_size = point_cloud.get_shape()[0].value
  num_point = point_cloud.get_shape()[1].value
  input_image = tf.expand_dims(point_cloud, -1)

  k = 30  # neighborhood size for all EdgeConv layers

  # Build the initial k-NN graph and edge features for the spatial transformer.
  adj = tf_util.pairwise_distance(point_cloud)
  nn_idx = tf_util.knn(adj, k=k)
  edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)

  with tf.variable_scope('transform_net1') as sc:
    transform = input_transform_net(edge_feature, is_training, bn_decay, K=3, is_dist=True)
  # Align the input with the predicted 3x3 transform, then rebuild the graph.
  point_cloud_transformed = tf.matmul(point_cloud, transform)
  input_image = tf.expand_dims(point_cloud_transformed, -1)
  adj = tf_util.pairwise_distance(point_cloud_transformed)
  nn_idx = tf_util.knn(adj, k=k)
  edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)

  # EdgeConv block 1: two 1x1 convs over edge features, then max and mean
  # aggregation over each point's k neighbors.
  out1 = tf_util.conv2d(edge_feature, 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv1', bn_decay=bn_decay, is_dist=True)

  out2 = tf_util.conv2d(out1, 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv2', bn_decay=bn_decay, is_dist=True)

  net_max_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
  net_mean_1 = tf.reduce_mean(out2, axis=-2, keep_dims=True)

  out3 = tf_util.conv2d(tf.concat([net_max_1, net_mean_1], axis=-1), 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv3', bn_decay=bn_decay, is_dist=True)

  # EdgeConv block 2: recompute the k-NN graph in feature space.
  adj = tf_util.pairwise_distance(tf.squeeze(out3, axis=-2))
  nn_idx = tf_util.knn(adj, k=k)
  edge_feature = tf_util.get_edge_feature(out3, nn_idx=nn_idx, k=k)

  out4 = tf_util.conv2d(edge_feature, 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv4', bn_decay=bn_decay, is_dist=True)

  net_max_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
  net_mean_2 = tf.reduce_mean(out4, axis=-2, keep_dims=True)

  out5 = tf_util.conv2d(tf.concat([net_max_2, net_mean_2], axis=-1), 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv5', bn_decay=bn_decay, is_dist=True)

  # EdgeConv block 3.
  adj = tf_util.pairwise_distance(tf.squeeze(out5, axis=-2))
  nn_idx = tf_util.knn(adj, k=k)
  edge_feature = tf_util.get_edge_feature(out5, nn_idx=nn_idx, k=k)

  out6 = tf_util.conv2d(edge_feature, 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv6', bn_decay=bn_decay, is_dist=True)

  net_max_3 = tf.reduce_max(out6, axis=-2, keep_dims=True)
  net_mean_3 = tf.reduce_mean(out6, axis=-2, keep_dims=True)

  out7 = tf_util.conv2d(tf.concat([net_max_3, net_mean_3], axis=-1), 64, [1,1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv7', bn_decay=bn_decay, is_dist=True)

  # Aggregate the three EdgeConv outputs into a global shape descriptor.
  out8 = tf_util.conv2d(tf.concat([out3, out5, out7], axis=-1), 1024, [1, 1],
                        padding='VALID', stride=[1,1],
                        bn=True, is_training=is_training,
                        scope='adj_conv13', bn_decay=bn_decay, is_dist=True)

  out_max = tf_util.max_pool2d(out8, [num_point, 1], padding='VALID', scope='maxpool')

  # Embed the one-hot object-category label and append it to the descriptor.
  one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num])
  one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 128, [1, 1],
                                        padding='VALID', stride=[1,1],
                                        bn=True, is_training=is_training,
                                        scope='one_hot_label_expand', bn_decay=bn_decay, is_dist=True)
  out_max = tf.concat(axis=3, values=[out_max, one_hot_label_expand])
  expand = tf.tile(out_max, [1, num_point, 1, 1])

  # Concatenate the tiled global descriptor with all per-point features.
  concat = tf.concat(axis=3, values=[expand,
                                     net_max_1,
                                     net_mean_1,
                                     out3,
                                     net_max_2,
                                     net_mean_2,
                                     out5,
                                     net_max_3,
                                     net_mean_3,
                                     out7,
                                     out8])

  # Per-point segmentation head: three conv layers with dropout, then logits.
  net2 = tf_util.conv2d(concat, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay,
                        bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay, is_dist=True)
  net2 = tf_util.dropout(net2, keep_prob=0.6, is_training=is_training, scope='seg/dp1')
  net2 = tf_util.conv2d(net2, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay,
                        bn=True, is_training=is_training, scope='seg/conv2', weight_decay=weight_decay, is_dist=True)
  net2 = tf_util.dropout(net2, keep_prob=0.6, is_training=is_training, scope='seg/dp2')
  net2 = tf_util.conv2d(net2, 128, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay,
                        bn=True, is_training=is_training, scope='seg/conv3', weight_decay=weight_decay, is_dist=True)
  net2 = tf_util.conv2d(net2, part_num, [1,1], padding='VALID', stride=[1,1], activation_fn=None,
                        bn=False, scope='seg/conv4', weight_decay=weight_decay, is_dist=True)

  net2 = tf.reshape(net2, [batch_size, num_point, part_num])

  return net2


def get_loss(seg_pred, seg):
  # Cross-entropy averaged over points (per shape), then over the batch.
  per_instance_seg_loss = tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(logits=seg_pred, labels=seg), axis=1)
  seg_loss = tf.reduce_mean(per_instance_seg_loss)
  per_instance_seg_pred_res = tf.argmax(seg_pred, 2)

  return seg_loss, per_instance_seg_loss, per_instance_seg_pred_res
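For context, here is a minimal sketch of how get_model and get_loss would be wired into a training graph. The placeholder shapes and hyperparameters are illustrative (ShapeNetPart has 16 object categories and 50 part labels); the commit's actual training script is not among the files shown above.

import tensorflow as tf
from part_seg_model import get_model, get_loss

BATCH_SIZE, NUM_POINT = 32, 2048
CAT_NUM, PART_NUM = 16, 50  # ShapeNetPart: 16 categories, 50 parts

pointclouds_ph = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_POINT, 3))
input_label_ph = tf.placeholder(tf.float32, shape=(BATCH_SIZE, CAT_NUM))  # one-hot category
seg_ph = tf.placeholder(tf.int32, shape=(BATCH_SIZE, NUM_POINT))          # per-point part ids
is_training_ph = tf.placeholder(tf.bool, shape=())

seg_pred = get_model(pointclouds_ph, input_label_ph, is_training_ph,
                     cat_num=CAT_NUM, part_num=PART_NUM,
                     batch_size=BATCH_SIZE, num_point=NUM_POINT,
                     weight_decay=0.0, bn_decay=None)
seg_loss, per_instance_loss, per_instance_pred = get_loss(seg_pred, seg_ph)
train_op = tf.train.AdamOptimizer(learning_rate=0.001).minimize(seg_loss)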






(5 more changed files not shown.)
