Compatibility with tensorflow 2, python3.8 #72

Open · wants to merge 1 commit into master
2 changes: 1 addition & 1 deletion tensorflow/provider.py
@@ -140,7 +140,7 @@ def getDataFiles(list_filename):
     return [line.rstrip() for line in open(list_filename)]
 
 def load_h5(h5_filename):
-    f = h5py.File(h5_filename)
+    f = h5py.File(h5_filename, 'r')
     data = f['data'][:]
     label = f['label'][:]
     return (data, label)
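Reviewer note: recent h5py releases deprecated the implicit file mode (2.x defaulted to 'a' and warned; 3.x defaults to read-only), so the explicit 'r' both silences the warning and states intent. A minimal sketch of the same loader with a context manager, assuming this repo's HDF5 layout with 'data' and 'label' datasets:

    import h5py

    def load_h5(h5_filename):
        # open read-only; the context manager closes the handle even if
        # reading a dataset raises
        with h5py.File(h5_filename, 'r') as f:
            data = f['data'][:]    # materialize the full array in memory
            label = f['label'][:]
        return (data, label)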
16 changes: 9 additions & 7 deletions tensorflow/sem_seg/model.py
@@ -11,20 +11,22 @@
 import tf_util
 
 def placeholder_inputs(batch_size, num_point):
-  pointclouds_pl = tf.placeholder(tf.float32,
+  pointclouds_pl = tf.compat.v1.placeholder(tf.float32,
                                   shape=(batch_size, num_point, 9))
-  labels_pl = tf.placeholder(tf.int32,
+  labels_pl = tf.compat.v1.placeholder(tf.int32,
                              shape=(batch_size, num_point))
   return pointclouds_pl, labels_pl
 
 def get_model(point_cloud, is_training, bn_decay=None):
   """ ConvNet baseline, input is BxNx9 gray image """
-  batch_size = point_cloud.get_shape()[0].value
-  num_point = point_cloud.get_shape()[1].value
+  batch_size = point_cloud.get_shape()[0]
+  num_point = point_cloud.get_shape()[1]
   input_image = tf.expand_dims(point_cloud, -1)
 
   k = 20
 
+  weight_decay = 0.0
+
   adj = tf_util.pairwise_distance(point_cloud[:, :, 6:])
   nn_idx = tf_util.knn(adj, k=k)  # (batch, num_points, k)
   edge_feature = tf_util.get_edge_feature(input_image, nn_idx=nn_idx, k=k)
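Reviewer note: dropping .value is the right TF2 fix here. TF1's get_shape()[0] returned a Dimension object holding its integer in .value; TF2's TensorShape indexing yields a plain Python int (or None when the dimension is unknown). A tiny sketch of the difference, assuming a statically shaped tensor under TF2:

    import tensorflow as tf

    x = tf.zeros((8, 4096, 9))
    batch_size = x.shape[0]  # plain int in TF2
    # TF1 spelling was: x.get_shape()[0].value
    assert isinstance(batch_size, int) and batch_size == 8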
@@ -39,7 +41,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                         bn=True, is_training=is_training, weight_decay=weight_decay,
                         scope='adj_conv2', bn_decay=bn_decay, is_dist=True)
 
-  net_1 = tf.reduce_max(out2, axis=-2, keep_dims=True)
+  net_1 = tf.compat.v1.reduce_max(out2, axis=-2, keep_dims=True)
 
@@ -57,7 +59,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
                         bn=True, is_training=is_training, weight_decay=weight_decay,
                         scope='adj_conv4', bn_decay=bn_decay, is_dist=True)
 
-  net_2 = tf.reduce_max(out4, axis=-2, keep_dims=True)
+  net_2 = tf.compat.v1.reduce_max(out4, axis=-2, keep_dims=True)
 
@@ -75,7 +77,7 @@ def get_model(point_cloud, is_training, bn_decay=None):
 #                       bn=True, is_training=is_training, weight_decay=weight_decay,
 #                       scope='adj_conv6', bn_decay=bn_decay, is_dist=True)
 
-  net_3 = tf.reduce_max(out5, axis=-2, keep_dims=True)
+  net_3 = tf.compat.v1.reduce_max(out5, axis=-2, keep_dims=True)
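Reviewer note: TF2 renamed the keyword to keepdims, so plain tf.reduce_max would reject keep_dims; routing through tf.compat.v1.reduce_max keeps the legacy spelling working. A sketch showing the two equivalent calls on a DGCNN-shaped activation (shape values are illustrative):

    import tensorflow as tf

    out = tf.random.uniform((2, 4096, 20, 64))  # (batch, point, k, channel)
    a = tf.compat.v1.reduce_max(out, axis=-2, keep_dims=True)  # legacy keyword
    b = tf.reduce_max(out, axis=-2, keepdims=True)             # TF2-native
    # both have shape (2, 4096, 1, 64)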
64 changes: 33 additions & 31 deletions tensorflow/sem_seg/train.py
@@ -2,6 +2,8 @@
 import math
 import h5py
 import numpy as np
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 import tensorflow as tf
 import socket
 
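Reviewer note: TF_CPP_MIN_LOG_LEVEL is read when tensorflow is imported, so setting it above the import is correct. One gap worth flagging: tf.compat.v1.placeholder and tf.compat.v1.Session still require graph mode, and stock TF2 starts in eager mode, so this script also needs eager execution disabled. A minimal sketch of the preamble I would expect (disable_eager_execution is the narrower alternative to tf.compat.v1.disable_v2_behavior):

    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # must be set before the import

    import tensorflow as tf
    tf.compat.v1.disable_eager_execution()  # placeholders/Sessions need graph mode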
@@ -17,11 +19,11 @@
 from model import *
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--num_gpu', type=int, default=2, help='the number of GPUs to use [default: 2]')
+parser.add_argument('--num_gpu', type=int, default=1, help='the number of GPUs to use [default: 1]')
 parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
 parser.add_argument('--num_point', type=int, default=4096, help='Point number [default: 4096]')
 parser.add_argument('--max_epoch', type=int, default=101, help='Epoch to run [default: 50]')
-parser.add_argument('--batch_size', type=int, default=12, help='Batch Size during training for each GPU [default: 24]')
+parser.add_argument('--batch_size', type=int, default=8, help='Batch Size during training for each GPU [default: 8]')
 parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
 parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
 parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
@@ -61,7 +63,7 @@
 
 ALL_FILES = provider.getDataFiles('indoor3d_sem_seg_hdf5_data/all_files.txt')
 room_filelist = [line.rstrip() for line in open('indoor3d_sem_seg_hdf5_data/room_filelist.txt')]
-print len(room_filelist)
+print(len(room_filelist))
 
 # Load ALL data
 data_batch_list = []
@@ -99,7 +101,7 @@ def log_string(out_str):
 
 
 def get_learning_rate(batch):
-    learning_rate = tf.train.exponential_decay(
+    learning_rate = tf.compat.v1.train.exponential_decay(
                         BASE_LEARNING_RATE,  # Base learning rate.
                         batch * BATCH_SIZE,  # Current index into the dataset.
                         DECAY_STEP,          # Decay step.
@@ -109,7 +111,7 @@ def get_learning_rate(batch):
     return learning_rate
 
 def get_bn_decay(batch):
-    bn_momentum = tf.train.exponential_decay(
+    bn_momentum = tf.compat.v1.train.exponential_decay(
                       BN_INIT_DECAY,
                       batch*BATCH_SIZE,
                       BN_DECAY_DECAY_STEP,
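Reviewer note: the compat shim is the smallest possible change, but if this script ever moves off compat.v1, TF2's native replacement for tf.train.exponential_decay is a Keras learning-rate schedule. A hedged sketch; the constants mirror this script's globals, the values shown are assumed defaults, and staircase=True matches the stepped decay the old call produces:

    import tensorflow as tf

    BASE_LEARNING_RATE = 0.001  # assumed; use the real FLAGS values
    DECAY_STEP = 300000
    DECAY_RATE = 0.5

    # the optimizer advances its own step counter, so no `batch` tensor is needed
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=BASE_LEARNING_RATE,
        decay_steps=DECAY_STEP,
        decay_rate=DECAY_RATE,
        staircase=True)
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)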
@@ -162,20 +164,21 @@ def train():
         learning_rate = get_learning_rate(batch)
         tf.summary.scalar('learning_rate', learning_rate)
 
-        trainer = tf.train.AdamOptimizer(learning_rate)
+        trainer = tf.compat.v1.train.AdamOptimizer(learning_rate)
+        #trainer = tf.keras.optimizers.Adam(learning_rate)
 
         tower_grads = []
         pointclouds_phs = []
         labels_phs = []
         is_training_phs =[]
 
-        with tf.variable_scope(tf.get_variable_scope()):
-            for i in xrange(FLAGS.num_gpu):
+        with tf.compat.v1.variable_scope(tf.compat.v1.get_variable_scope()):
+            for i in range(FLAGS.num_gpu):
                 with tf.device('/gpu:%d' % i):
                     with tf.name_scope('%s_%d' % (TOWER_NAME, i)) as scope:
 
                         pointclouds_pl, labels_pl = placeholder_inputs(BATCH_SIZE, NUM_POINT)
-                        is_training_pl = tf.placeholder(tf.bool, shape=())
+                        is_training_pl = tf.compat.v1.placeholder(tf.bool, shape=())
 
                         pointclouds_phs.append(pointclouds_pl)
                         labels_phs.append(labels_pl)
@@ -185,11 +188,11 @@ def train():
                         loss = get_loss(pred, labels_phs[-1])
                         tf.summary.scalar('loss', loss)
 
-                        correct = tf.equal(tf.argmax(pred, 2), tf.to_int64(labels_phs[-1]))
+                        correct = tf.equal(tf.argmax(pred, 2), tf.compat.v1.to_int64(labels_phs[-1]))
                         accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE*NUM_POINT)
                         tf.summary.scalar('accuracy', accuracy)
 
-                        tf.get_variable_scope().reuse_variables()
+                        tf.compat.v1.get_variable_scope().reuse_variables()
 
                         grads = trainer.compute_gradients(loss)
@@ -199,23 +202,23 @@
 
         train_op = trainer.apply_gradients(grads, global_step=batch)
 
-        saver = tf.train.Saver(tf.global_variables(), sharded=True, max_to_keep=10)
+        saver = tf.compat.v1.train.Saver(tf.compat.v1.global_variables(), sharded=True, max_to_keep=10)
 
         # Create a session
-        config = tf.ConfigProto()
+        config = tf.compat.v1.ConfigProto()
         config.gpu_options.allow_growth = True
         config.allow_soft_placement = True
-        sess = tf.Session(config=config)
+        sess = tf.compat.v1.Session(config=config)
 
         # Add summary writers
-        merged = tf.summary.merge_all()
-        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
+        merged = tf.compat.v1.summary.merge_all()
+        train_writer = tf.compat.v1.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                              sess.graph)
-        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
+        test_writer = tf.compat.v1.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
 
-        # Init variables for two GPUs
-        init = tf.group(tf.global_variables_initializer(),
-                        tf.local_variables_initializer())
+        # Init variables for one GPU
+        init = tf.group(tf.compat.v1.global_variables_initializer(),
+                        tf.compat.v1.local_variables_initializer())
         sess.run(init)
 
         ops = {'pointclouds_phs': pointclouds_phs,
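Reviewer note: ConfigProto and Session only exist behind compat.v1, which is fine for this graph-mode script. For reference, the TF2-native way to get the same allow_growth behavior is the device config API, called before any GPU work starts. A sketch:

    import tensorflow as tf

    # TF2-native equivalent of config.gpu_options.allow_growth = True
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)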
@@ -259,21 +262,20 @@ def train_one_epoch(sess, ops, train_writer):
         print('Current batch/total batch num: %d/%d'%(batch_idx,num_batches))
         start_idx_0 = batch_idx * BATCH_SIZE
         end_idx_0 = (batch_idx+1) * BATCH_SIZE
-        start_idx_1 = (batch_idx+1) * BATCH_SIZE
-        end_idx_1 = (batch_idx+2) * BATCH_SIZE
 
         feed_dict = {ops['pointclouds_phs'][0]: current_data[start_idx_0:end_idx_0, :, :],
-                     ops['pointclouds_phs'][1]: current_data[start_idx_1:end_idx_1, :, :],
                      ops['labels_phs'][0]: current_label[start_idx_0:end_idx_0],
-                     ops['labels_phs'][1]: current_label[start_idx_1:end_idx_1],
-                     ops['is_training_phs'][0]: is_training,
-                     ops['is_training_phs'][1]: is_training}
-        summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
-                                                        feed_dict=feed_dict)
-        train_writer.add_summary(summary, step)
+                     ops['is_training_phs'][0]: is_training}
+
+        # summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
+        #                                                  feed_dict=feed_dict)
+
+        step, _, loss_val, pred_val = sess.run([ops['step'], ops['train_op'], ops['loss'], ops['pred']],
+                                               feed_dict=feed_dict)
+        # train_writer.add_summary(summary, step)
         pred_val = np.argmax(pred_val, 2)
-        correct = np.sum(pred_val == current_label[start_idx_1:end_idx_1])
+        correct = np.sum(pred_val == current_label[start_idx_0:end_idx_0])
         total_correct += correct
         total_seen += (BATCH_SIZE*NUM_POINT)
         loss_sum += loss_val
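Reviewer note: with the merged-summary fetch commented out, train_writer is never given anything to write, so the TensorBoard curves go dark and the merged op becomes dead code. Restoring the original fetch in the single-GPU path keeps logging alive; this sketch reuses the script's own ops dict and feed_dict:

    summary, step, _, loss_val, pred_val = sess.run(
        [ops['merged'], ops['step'], ops['train_op'], ops['loss'], ops['pred']],
        feed_dict=feed_dict)
    train_writer.add_summary(summary, step)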