diff --git a/part_seg/part_seg_model.py b/part_seg/part_seg_model.py index 19a9a12..227f9a6 100644 --- a/part_seg/part_seg_model.py +++ b/part_seg/part_seg_model.py @@ -18,7 +18,7 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ num_point = point_cloud.get_shape()[1].value input_image = tf.expand_dims(point_cloud, -1) - k = 30 + k = 20 adj = tf_util.pairwise_distance(point_cloud) nn_idx = tf_util.knn(adj, k=k) @@ -27,6 +27,7 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ with tf.variable_scope('transform_net1') as sc: transform = input_transform_net(edge_feature, is_training, bn_decay, K=3, is_dist=True) point_cloud_transformed = tf.matmul(point_cloud, transform) + input_image = tf.expand_dims(point_cloud_transformed, -1) adj = tf_util.pairwise_distance(point_cloud_transformed) nn_idx = tf_util.knn(adj, k=k) @@ -42,57 +43,56 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ bn=True, is_training=is_training, weight_decay=weight_decay, scope='adj_conv2', bn_decay=bn_decay, is_dist=True) - net_max_1 = tf.reduce_max(out2, axis=-2, keep_dims=True) - net_mean_1 = tf.reduce_mean(out2, axis=-2, keep_dims=True) + net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True) + + - out3 = tf_util.conv2d(tf.concat([net_max_1, net_mean_1], axis=-1), 64, [1,1], + adj = tf_util.pairwise_distance(net_1) + nn_idx = tf_util.knn(adj, k=k) + edge_feature = tf_util.get_edge_feature(net_1, nn_idx=nn_idx, k=k) + + out3 = tf_util.conv2d(edge_feature, 64, [1,1], padding='VALID', stride=[1,1], bn=True, is_training=is_training, weight_decay=weight_decay, scope='adj_conv3', bn_decay=bn_decay, is_dist=True) - adj = tf_util.pairwise_distance(tf.squeeze(out3, axis=-2)) - nn_idx = tf_util.knn(adj, k=k) - edge_feature = tf_util.get_edge_feature(out3, nn_idx=nn_idx, k=k) - - out4 = tf_util.conv2d(edge_feature, 64, [1,1], + out4 = tf_util.conv2d(out3, 64, [1,1], padding='VALID', stride=[1,1], bn=True, is_training=is_training, weight_decay=weight_decay, scope='adj_conv4', bn_decay=bn_decay, is_dist=True) - net_max_2 = tf.reduce_max(out4, axis=-2, keep_dims=True) - net_mean_2 = tf.reduce_mean(out4, axis=-2, keep_dims=True) + net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True) + + - out5 = tf_util.conv2d(tf.concat([net_max_2, net_mean_2], axis=-1), 64, [1,1], + adj = tf_util.pairwise_distance(net_2) + nn_idx = tf_util.knn(adj, k=k) + edge_feature = tf_util.get_edge_feature(net_2, nn_idx=nn_idx, k=k) + + out5 = tf_util.conv2d(edge_feature, 64, [1,1], padding='VALID', stride=[1,1], bn=True, is_training=is_training, weight_decay=weight_decay, scope='adj_conv5', bn_decay=bn_decay, is_dist=True) - adj = tf_util.pairwise_distance(tf.squeeze(out5, axis=-2)) - nn_idx = tf_util.knn(adj, k=k) - edge_feature = tf_util.get_edge_feature(out5, nn_idx=nn_idx, k=k) + # out6 = tf_util.conv2d(out5, 64, [1,1], + # padding='VALID', stride=[1,1], + # bn=True, is_training=is_training, weight_decay=weight_decay, + # scope='adj_conv6', bn_decay=bn_decay, is_dist=True) - out6 = tf_util.conv2d(edge_feature, 64, [1,1], - padding='VALID', stride=[1,1], - bn=True, is_training=is_training, weight_decay=weight_decay, - scope='adj_conv6', bn_decay=bn_decay, is_dist=True) + net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True) - net_max_3 = tf.reduce_max(out6, axis=-2, keep_dims=True) - net_mean_3 = tf.reduce_mean(out6, axis=-2, keep_dims=True) - out7 = tf_util.conv2d(tf.concat([net_max_3, net_mean_3], axis=-1), 64, [1,1], - padding='VALID', stride=[1,1], - bn=True, is_training=is_training, weight_decay=weight_decay, - scope='adj_conv7', bn_decay=bn_decay, is_dist=True) - out8 = tf_util.conv2d(tf.concat([out3, out5, out7], axis=-1), 1024, [1, 1], + out7 = tf_util.conv2d(tf.concat([net_1, net_2, net_3], axis=-1), 1024, [1, 1], padding='VALID', stride=[1,1], bn=True, is_training=is_training, - scope='adj_conv13', bn_decay=bn_decay, is_dist=True) + scope='adj_conv7', bn_decay=bn_decay, is_dist=True) + + out_max = tf_util.max_pool2d(out7, [num_point, 1], padding='VALID', scope='maxpool') - out_max = tf_util.max_pool2d(out8, [num_point, 1], padding='VALID', scope='maxpool') one_hot_label_expand = tf.reshape(input_label, [batch_size, 1, 1, cat_num]) - one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 128, [1, 1], + one_hot_label_expand = tf_util.conv2d(one_hot_label_expand, 64, [1, 1], padding='VALID', stride=[1,1], bn=True, is_training=is_training, scope='one_hot_label_expand', bn_decay=bn_decay, is_dist=True) @@ -100,16 +100,9 @@ def get_model(point_cloud, input_label, is_training, cat_num, part_num, \ expand = tf.tile(out_max, [1, num_point, 1, 1]) concat = tf.concat(axis=3, values=[expand, - net_max_1, - net_mean_1, - out3, - net_max_2, - net_mean_2, - out5, - net_max_3, - net_mean_3, - out7, - out8]) + net_1, + net_2, + net_3]) net2 = tf_util.conv2d(concat, 256, [1,1], padding='VALID', stride=[1,1], bn_decay=bn_decay, bn=True, is_training=is_training, scope='seg/conv1', weight_decay=weight_decay, is_dist=True) @@ -134,8 +127,3 @@ def get_loss(seg_pred, seg): return seg_loss, per_instance_seg_loss, per_instance_seg_pred_res - - - - - diff --git a/part_seg/train_multi_gpu.py b/part_seg/train_multi_gpu.py index 497dea7..bffdfae 100644 --- a/part_seg/train_multi_gpu.py +++ b/part_seg/train_multi_gpu.py @@ -378,7 +378,7 @@ def eval_one_epoch(epoch_num): train_one_epoch(train_file_idx, epoch) - if epoch % 10 == 0: + if epoch % 5 == 0: cp_filename = saver.save(sess, os.path.join(MODEL_STORAGE_PATH, 'epoch_' + str(epoch)+'.ckpt')) printout(flog, 'Successfully store the checkpoint model into ' + cp_filename)