Skip to content

Commit

Permalink
modify segmentation model structure
Browse files Browse the repository at this point in the history
  • Loading branch information
syb7573330 committed Dec 10, 2018
1 parent d767b0e commit 07f9ee1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 45 deletions.
76 changes: 32 additions & 44 deletions part_seg/part_seg_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \
num_point = point_cloud.get_shape()[1].value
input_image = tf.expand_dims(point_cloud, -1)

k = 30
k = 20

adj = tf_util.pairwise_distance(point_cloud)
nn_idx = tf_util.knn(adj, k=k)
Expand All @@ -27,6 +27,7 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \
with tf.variable_scope('transform_net1') as sc:
transform = input_transform_net(edge_feature, is_training, bn_decay, K=3, is_dist=True)
point_cloud_transformed = tf.matmul(point_cloud, transform)

input_image = tf.expand_dims(point_cloud_transformed, -1)
adj = tf_util.pairwise_distance(point_cloud_transformed)
nn_idx = tf_util.knn(adj, k=k)
Expand All @@ -42,74 +43,66 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv2', bn_decay=bn_decay, is_dist=True)

net_max_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
net_mean_1 = tf.reduce_mean(out2, axis=-2, keep_dims=True)
net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)



out3 = tf_util.conv2d(tf.concat([net_max_1, net_mean_1], axis=-1), 64, [1,1],
adj = tf_util.pairwise_distance(net_1)
nn_idx = tf_util.knn(adj, k=k)
edge_feature = tf_util.get_edge_feature(net_1, nn_idx=nn_idx, k=k)

out3 = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv3', bn_decay=bn_decay, is_dist=True)

adj = tf_util.pairwise_distance(tf.squeeze(out3, axis=-2))
nn_idx = tf_util.knn(adj, k=k)
edge_feature = tf_util.get_edge_feature(out3, nn_idx=nn_idx, k=k)

out4 = tf_util.conv2d(edge_feature, 64, [1,1],
out4 = tf_util.conv2d(out3, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv4', bn_decay=bn_decay, is_dist=True)

net_max_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
net_mean_2 = tf.reduce_mean(out4, axis=-2, keep_dims=True)
net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)



out5 = tf_util.conv2d(tf.concat([net_max_2, net_mean_2], axis=-1), 64, [1,1],
adj = tf_util.pairwise_distance(net_2)
nn_idx = tf_util.knn(adj, k=k)
edge_feature = tf_util.get_edge_feature(net_2, nn_idx=nn_idx, k=k)

out5 = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv5', bn_decay=bn_decay, is_dist=True)

adj = tf_util.pairwise_distance(tf.squeeze(out5, axis=-2))
nn_idx = tf_util.knn(adj, k=k)
edge_feature = tf_util.get_edge_feature(out5, nn_idx=nn_idx, k=k)
# out6 = tf_util.conv2d(out5, 64, [1,1],
# padding='VALID', stride=[1,1],
# bn=True, is_training=is_training, weight_decay=weight_decay,
# scope='adj_conv6', bn_decay=bn_decay, is_dist=True)

out6 = tf_util.conv2d(edge_feature, 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv6', bn_decay=bn_decay, is_dist=True)
net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True)

net_max_3 = tf.reduce_max(out6, axis=-2, keep_dims=True)
net_mean_3 = tf.reduce_mean(out6, axis=-2, keep_dims=True)

out7 = tf_util.conv2d(tf.concat([net_max_3, net_mean_3], axis=-1), 64, [1,1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training, weight_decay=weight_decay,
scope='adj_conv7', bn_decay=bn_decay, is_dist=True)

out8 = tf_util.conv2d(tf.concat([out3, out5, out7], axis=-1), 1024, [1, 1],
out7 = tf_util.conv2d(tf.concat([net_1, net_2, net_3], axis=-1), 1024, [1, 1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='adj_conv13', bn_decay=bn_decay, is_dist=True)
scope='adj_conv7', bn_decay=bn_decay, is_dist=True)

out_max = tf_util.max_pool2d(out7, [num_point, 1], padding='VALID', scope='maxpool')

out_max = tf_util.max_pool2d(out8, [num_point, 1], padding='VALID', scope='maxpool')

one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num])
one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 128, [1, 1],
one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 64, [1, 1],
padding='VALID', stride=[1,1],
bn=True, is_training=is_training,
scope='one_hot_label_expand', bn_decay=bn_decay, is_dist=True)
out_max = tf.concat(axis=3, values=[out_max, one_hot_label_expand])
expand = tf.tile(out_max, [1, num_point, 1, 1])

concat = tf.concat(axis=3, values=[expand,
net_max_1,
net_mean_1,
out3,
net_max_2,
net_mean_2,
out5,
net_max_3,
net_mean_3,
out7,
out8])
net_1,
net_2,
net_3])

net2 = tf_util.conv2d(concat, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay,
bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay, is_dist=True)
Expand All @@ -134,8 +127,3 @@ def get_loss(seg_pred, seg):

return seg_loss, per_instance_seg_loss, per_instance_seg_pred_res






2 changes: 1 addition & 1 deletion part_seg/train_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def eval_one_epoch(epoch_num):

train_one_epoch(train_file_idx, epoch)

if epoch % 10 == 0:
if epoch % 5 == 0:
cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch)+'.ckpt'))
printout(flog, 'Successfully store the checkpoint model into ' + cp_filename)

Expand Down

0 comments on commit 07f9ee1

Please sign in to comment.