diff --git a/tensorflow/provider.py b/tensorflow/provider.py
index 0005dbe..ef426aa 100644
--- a/tensorflow/provider.py
+++ b/tensorflow/provider.py
@@ -140,7 +140,7 @@ def getDataFiles(list_filename):
     return [line.rstrip() for line in open(list_filename)]
 
 def load_h5(h5_filename):
-    f = h5py.File(h5_filename)
+    f = h5py.File(h5_filename, 'r')
     data = f['data'][:]
     label = f['label'][:]
     return (data, label)
diff --git a/tensorflow/sem_seg/model.py b/tensorflow/sem_seg/model.py
index 2a4aeea..45b3695 100644
--- a/tensorflow/sem_seg/model.py
+++ b/tensorflow/sem_seg/model.py
@@ -11,20 +11,22 @@ import tf_util
 
 def placeholder_inputs(batch_size, num_point):
-  pointclouds_pl = tf.placeholder(tf.float32,
+  pointclouds_pl = tf.compat.v1.placeholder(tf.float32,
                                   shape=(batch_size, num_point, 9))
-  labels_pl = tf.placeholder(tf.int32,
+  labels_pl = tf.compat.v1.placeholder(tf.int32,
                              shape=(batch_size, num_point))
   return pointclouds_pl, labels_pl
 
 def get_model(point_cloud, is_training, bn_decay=None):
   """ ConvNet baseline, input is BxNx9 gray image """
-  batch_size = point_cloud.get_shape()[0].value
-  num_point = point_cloud.get_shape()[1].value
+  batch_size = point_cloud.get_shape()[0]
+  num_point = point_cloud.get_shape()[1]
   input_image = tf.expand_dims(point_cloud, -1)
 
   k = 20
+  weight_decay = 0.0
+
   adj = tf_util.pairwise_distance(point_cloud[:, :, 6:])
   nn_idx = tf_util.knn(adj, k=k)  # (batch, num_points, k)
   edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)
@@ -39,7 +41,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv2', bn_decay=bn_decay, is_dist=True)
 
-  net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
+  net_1 = tf.compat.v1.reduce_max(out2, axis=-2, keep_dims=True)
@@ -57,7 +59,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                        bn=True, is_training=is_training, weight_decay=weight_decay,
                        scope='adj_conv4', bn_decay=bn_decay, is_dist=True)
 
-  net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
+  net_2 = tf.compat.v1.reduce_max(out4, axis=-2, keep_dims=True)
@@ -75,7 +77,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
 #                       bn=True, is_training=is_training, weight_decay=weight_decay,
 #                       scope='adj_conv6', bn_decay=bn_decay, is_dist=True)
 
-  net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True)
+  net_3 = tf.compat.v1.reduce_max(out5, axis=-2, keep_dims=True)
diff --git a/tensorflow/sem_seg/train.py b/tensorflow/sem_seg/train.py
index bdec182..16eaa85 100644
--- a/tensorflow/sem_seg/train.py
+++ b/tensorflow/sem_seg/train.py
@@ -2,6 +2,8 @@ import math
 import h5py
 import numpy as np
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 import socket
@@ -17,11 +19,11 @@ from model import *
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--num_gpu', type=int, default=2, help='the number of GPUs to use [default: 2]')
+parser.add_argument('--num_gpu', type=int, default=1, help='the number of GPUs to use [default: 2]')
 parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
 parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
 parser.add_argument('--max_epoch', type=int, default=101, help='Epoch to run [default: 50]')
-parser.add_argument('--batch_size', type=int, default=12, help='Batch Size during training for each GPU [default: 24]')
+parser.add_argument('--batch_size', type=int, default=8, help='Batch Size during training for each GPU [default: 24]')
 parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
 parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
 parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
@@ -61,7 +63,7 @@ ALL_FILES = provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt')
 room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')]
-print len(room_filelist)
+print(len(room_filelist))
 
 # Load ALL data
 data_batch_list = []
@@ -99,7 +101,7 @@ def log_string(out_str):
 
 def get_learning_rate(batch):
-    learning_rate = tf.train.exponential_decay(
+    learning_rate = tf.compat.v1.train.exponential_decay(
                         BASE_LEARNING_RATE,  # Base learning rate.
                         batch * BATCH_SIZE,  # Current index into the dataset.
                         DECAY_STEP,          # Decay step.
@@ -109,7 +111,7 @@ def get_learning_rate(batch):
     return learning_rate
 
 def get_bn_decay(batch):
-    bn_momentum = tf.train.exponential_decay(
+    bn_momentum = tf.compat.v1.train.exponential_decay(
                       BN_INIT_DECAY,
                       batch*BATCH_SIZE,
                       BN_DECAY_DECAY_STEP,
@@ -162,20 +164,21 @@ def train():
         learning_rate = get_learning_rate(batch)
         tf.summary.scalar('learning_rate', learning_rate)
 
-        trainer = tf.train.AdamOptimizer(learning_rate)
+        trainer = tf.compat.v1.train.AdamOptimizer(learning_rate)
+        #trainer = tf.keras.optimizers.Adam(learning_rate)
 
         tower_grads = []
         pointclouds_phs = []
         labels_phs = []
         is_training_phs =[]
 
-        with tf.variable_scope(tf.get_variable_scope()):
-            for i in xrange(FLAGS.num_gpu):
+        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
+            for i in range(FLAGS.num_gpu):
                 with tf.device('/gpu:%d' % i):
                     with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
                         pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
-                        is_training_pl = tf.placeholder(tf.bool, shape=())
+                        is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
 
                         pointclouds_phs.append(pointclouds_pl)
                         labels_phs.append(labels_pl)
@@ -185,11 +188,11 @@ def train():
                         loss = get_loss(pred, labels_phs[-1])
                         tf.summary.scalar('loss', loss)
 
-                        correct = tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_phs[-1]))
+                        correct = tf.equal(tf.argmax(pred, 2), tf.compat.v1.to_int64(labels_phs[-1]))
                         accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT)
                         tf.summary.scalar('accuracy', accuracy)
 
-                        tf.get_variable_scope().reuse_variables()
+                        tf.compat.v1.get_variable_scope().reuse_variables()
 
                         grads = trainer.compute_gradients(loss)
@@ -199,23 +202,23 @@ def train():
 
         train_op = trainer.apply_gradients(grads, global_step=batch)
 
-        saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10)
+        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10)
 
         # Create a session
-        config = tf.ConfigProto()
+        config = tf.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
         config.allow_soft_placement = True
-        sess = tf.Session(config=config)
+        sess = tf.compat.v1.Session(config=config)
 
         # Add summary writers
-        merged = tf.summary.merge_all()
-        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
+        merged = tf.compat.v1.summary.merge_all()
+        train_writer = tf.compat.v1.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                              sess.graph)
-        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
+        test_writer = tf.compat.v1.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
 
-        # Init variables for two GPUs
-        init = tf.group(tf.global_variables_initializer(),
-                 tf.local_variables_initializer())
+        # Init variables for one GPU
+        init = tf.group(tf.compat.v1.global_variables_initializer(),
+                 tf.compat.v1.local_variables_initializer())
         sess.run(init)
 
         ops = {'pointclouds_phs': pointclouds_phs,
@@ -259,21 +262,20 @@ def train_one_epoch(sess, ops, train_writer):
             print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches))
         start_idx_0 = batch_idx * BATCH_SIZE
         end_idx_0 = (batch_idx+1) * BATCH_SIZE
-        start_idx_1 = (batch_idx+1) * BATCH_SIZE
-        end_idx_1 = (batch_idx+2) * BATCH_SIZE
-
-
+
         feed_dict = {ops['pointclouds_phs'][0]: current_data[start_idx_0:end_idx_0, :, :],
-                     ops['pointclouds_phs'][1]: current_data[start_idx_1:end_idx_1, :, :],
                      ops['labels_phs'][0]: current_label[start_idx_0:end_idx_0],
-                     ops['labels_phs'][1]: current_label[start_idx_1:end_idx_1],
-                     ops['is_training_phs'][0]: is_training,
-                     ops['is_training_phs'][1]: is_training}
-        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
+                     ops['is_training_phs'][0]: is_training}
+
+        # summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
+        #                                                 feed_dict=feed_dict)
+
+        step, _, loss_val, pred_val = sess.run([ops['step'], ops['train_op'], ops['loss'], ops['pred']],
                                                feed_dict=feed_dict)
-        train_writer.add_summary(summary, step)
+
+        # train_writer.add_summary(summary, step)
         pred_val = np.argmax(pred_val, 2)
-        correct = np.sum(pred_val == current_label[start_idx_1:end_idx_1])
+        correct = np.sum(pred_val == current_label[start_idx_0:end_idx_0])
         total_correct += correct
         total_seen += (BATCH_SIZE*NUM_POINT)
         loss_sum += loss_val
diff --git a/tensorflow/utils/tf_util.py b/tensorflow/utils/tf_util.py
index 22ef62c..c0e6a07 100644
--- a/tensorflow/utils/tf_util.py
+++ b/tensorflow/utils/tf_util.py
@@ -20,7 +20,7 @@ def _variable_on_cpu(name, shape, initializer, use_fp16=False, trainable=True):
   """
   with tf.device('/cpu:0'):
     dtype = tf.float16 if use_fp16 else tf.float32
-    var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable)
+    var = tf.compat.v1.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable)
   return var
 
 def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True):
@@ -41,13 +41,13 @@ def _variable_with_weight_decay(name, shape, stddev, wd, use_xavier=True):
   Returns:
     Variable Tensor
   """
   if use_xavier:
-    initializer = tf.contrib.layers.xavier_initializer()
+    initializer = tf.initializers.GlorotUniform()
   else:
     initializer = tf.truncated_normal_initializer(stddev=stddev)
   var = _variable_on_cpu(name, shape, initializer)
   if wd is not None:
     weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
-    tf.add_to_collection('losses', weight_decay)
+    tf.compat.v1.add_to_collection('losses', weight_decay)
   return var
@@ -85,8 +85,8 @@ def conv1d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
-    num_in_channels = inputs.get_shape()[-1].value
+  with tf.compat.v1.variable_scope(scope) as sc:
+    num_in_channels = inputs.get_shape()[-1]
     kernel_shape = [kernel_size,
                     num_in_channels, num_output_channels]
     kernel = _variable_with_weight_decay('weights',
@@ -146,9 +146,9 @@ def conv2d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_h, kernel_w = kernel_size
-    num_in_channels = inputs.get_shape()[-1].value
+    num_in_channels = inputs.get_shape()[-1]
     kernel_shape = [kernel_h, kernel_w,
                     num_in_channels, num_output_channels]
     kernel = _variable_with_weight_decay('weights',
@@ -209,9 +209,9 @@ def conv2d_transpose(inputs,
   Note: conv2d(conv2d_transpose(a, num_out, ksize, stride), a.shape[-1], ksize, stride) == a
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_h, kernel_w = kernel_size
-    num_in_channels = inputs.get_shape()[-1].value
+    num_in_channels = inputs.get_shape()[-1]
     kernel_shape = [kernel_h, kernel_w,
                     num_output_channels, num_in_channels] # reversed to conv2d
     kernel = _variable_with_weight_decay('weights',
@@ -230,9 +230,9 @@ def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
       return dim_size
 
     # caculate output shape
-    batch_size = inputs.get_shape()[0].value
-    height = inputs.get_shape()[1].value
-    width = inputs.get_shape()[2].value
+    batch_size = inputs.get_shape()[0]
+    height = inputs.get_shape()[1]
+    width = inputs.get_shape()[2]
     out_height = get_deconv_dim(height, stride_h, kernel_h, padding)
     out_width = get_deconv_dim(width, stride_w, kernel_w, padding)
     output_shape = [batch_size, out_height, out_width, num_output_channels]
@@ -288,9 +288,9 @@ def conv3d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_d, kernel_h, kernel_w = kernel_size
-    num_in_channels = inputs.get_shape()[-1].value
+    num_in_channels = inputs.get_shape()[-1]
     kernel_shape = [kernel_d, kernel_h, kernel_w,
                     num_in_channels, num_output_channels]
     kernel = _variable_with_weight_decay('weights',
@@ -334,8 +334,8 @@ def fully_connected(inputs,
   Returns:
     Variable tensor of size B x num_outputs.
   """
-  with tf.variable_scope(scope) as sc:
-    num_input_units = inputs.get_shape()[-1].value
+  with tf.compat.v1.variable_scope(scope) as sc:
+    num_input_units = inputs.get_shape()[-1]
     weights = _variable_with_weight_decay('weights',
                                           shape=[num_input_units, num_outputs],
                                           use_xavier=use_xavier,
@@ -369,7 +369,7 @@ def max_pool2d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_h, kernel_w = kernel_size
     stride_h, stride_w = stride
     outputs = tf.nn.max_pool(inputs,
@@ -394,7 +394,7 @@ def avg_pool2d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_h, kernel_w = kernel_size
     stride_h, stride_w = stride
     outputs = tf.nn.avg_pool(inputs,
@@ -420,7 +420,7 @@ def max_pool3d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_d, kernel_h, kernel_w = kernel_size
     stride_d, stride_h, stride_w = stride
     outputs = tf.nn.max_pool3d(inputs,
@@ -445,7 +445,7 @@ def avg_pool3d(inputs,
   Returns:
     Variable tensor
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     kernel_d, kernel_h, kernel_w = kernel_size
     stride_d, stride_h, stride_w = stride
     outputs = tf.nn.avg_pool3d(inputs,
@@ -472,8 +472,8 @@ def batch_norm_template(inputs, is_training, scope, moments_dims, bn_decay):
   Return:
     normed:      batch-normalized maps
   """
-  with tf.variable_scope(scope) as sc:
-    num_channels = inputs.get_shape()[-1].value
+  with tf.compat.v1.variable_scope(scope) as sc:
+    num_channels = inputs.get_shape()[-1]
     beta = tf.Variable(tf.constant(0.0, shape=[num_channels]),
                        name='beta', trainable=True)
     gamma = tf.Variable(tf.constant(1.0, shape=[num_channels]),
@@ -510,8 +510,8 @@ def batch_norm_dist_template(inputs, is_training, scope, moments_dims, bn_decay)
   Return:
     normed:      batch-normalized maps
   """
-  with tf.variable_scope(scope) as sc:
-    num_channels = inputs.get_shape()[-1].value
+  with tf.compat.v1.variable_scope(scope) as sc:
+    num_channels = inputs.get_shape()[-1]
     beta = _variable_on_cpu('beta', [num_channels], initializer=tf.zeros_initializer())
     gamma = _variable_on_cpu('gamma', [num_channels], initializer=tf.ones_initializer())
@@ -521,8 +521,8 @@ def batch_norm_dist_template(inputs, is_training, scope, moments_dims, bn_decay)
     def train_bn_op():
       batch_mean, batch_var = tf.nn.moments(inputs, moments_dims, name='moments')
       decay = bn_decay if bn_decay is not None else 0.9
-      train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
-      train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
+      train_mean = tf.compat.v1.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
+      train_var = tf.compat.v1.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
       with tf.control_dependencies([train_mean, train_var]):
         return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, gamma, 1e-3)
@@ -628,7 +628,7 @@ def dropout(inputs,
   Returns:
     tensor variable
   """
-  with tf.variable_scope(scope) as sc:
+  with tf.compat.v1.variable_scope(scope) as sc:
     outputs = tf.cond(is_training,
                       lambda: tf.nn.dropout(inputs, keep_prob, noise_shape),
                       lambda: inputs)
@@ -652,7 +652,7 @@ def pairwise_distance(point_cloud):
   point_cloud_transpose = tf.transpose(point_cloud, perm=[0, 2, 1])
   point_cloud_inner = tf.matmul(point_cloud, point_cloud_transpose)
   point_cloud_inner = -2*point_cloud_inner
-  point_cloud_square = tf.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True)
+  point_cloud_square = tf.compat.v1.reduce_sum(tf.square(point_cloud), axis=-1, keep_dims=True)
   point_cloud_square_tranpose = tf.transpose(point_cloud_square, perm=[0, 2, 1])
   return point_cloud_square + point_cloud_inner + point_cloud_square_tranpose
@@ -689,9 +689,9 @@ def get_edge_feature(point_cloud, nn_idx, k=20):
   point_cloud_central = point_cloud
 
   point_cloud_shape = point_cloud.get_shape()
-  batch_size = point_cloud_shape[0].value
-  num_points = point_cloud_shape[1].value
-  num_dims = point_cloud_shape[2].value
+  batch_size = point_cloud_shape[0]
+  num_points = point_cloud_shape[1]
+  num_dims = point_cloud_shape[2]
 
   idx_ = tf.range(batch_size) * num_points
   idx_ = tf.reshape(idx_, [batch_size, 1, 1])
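
Usage note (not part of the patch): the tf.compat.v1 calls introduced above keep the scripts as TF1-style graph code. A minimal sketch of the surrounding setup, assuming a TensorFlow 2.x runtime, is shown below; under TF 2.x, graph-mode pieces such as tf.compat.v1.placeholder and tf.compat.v1.Session only work once eager execution has been disabled, and the placeholder shapes here simply mirror placeholder_inputs() with the patched defaults (batch_size=8, num_point=4096) for illustration.

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # silence TF C++ logging, as in the patched train.py

    import tensorflow as tf

    # Assumption: TF 2.x runtime. Graph-mode APIs used by the patch (placeholders,
    # Session, variable scopes) require eager execution to be turned off first.
    tf.compat.v1.disable_eager_execution()

    # Illustrative placeholders matching the patched defaults (8 x 4096 x 9 input).
    pointclouds_pl = tf.compat.v1.placeholder(tf.float32, shape=(8, 4096, 9))
    is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())

    # Session config mirrors the patched train.py.
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    with tf.compat.v1.Session(config=config) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())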