-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path test_kitti.py
137 lines (105 loc) · 4.21 KB
/
test_kitti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# ----------------------------------------------------------------------------------------------------------------------
#
# Callable script to test any model on kitti raw data
#
# ----------------------------------------------------------------------------------------------------------------------
#
# Imports and global variables
# \**********************************/
#
# Common libs
import time
import os
import numpy as np
import argparse
# My libs
from datasets.ShapeNetBenchmark2048 import ShapeNetBenchmark2048Dataset
from utils.config import Config
from utils.tester import ModelTester
from models.KPCN_model import KernelPointCompletionNetwork
# Datasets
from datasets.kitti import KittiDataset
def _select_snapshot(snap_dir, step_ind):
    """Return the snapshot path prefix ``<snap_dir>/snap-<step>`` chosen by ``step_ind``.

    Available steps are parsed from the ``*.meta`` files in ``snap_dir``
    (file names of the form ``<prefix>-<step>.meta``).

    Args:
        snap_dir: Folder containing the saved snapshots.
        step_ind: Negative values index the sorted list of available steps
            (``-1`` = latest, ``-2`` = the one before, ...). Non-negative
            values select step ``step_ind + 1`` directly.

    Raises:
        ValueError: if a negative ``step_ind`` asks for more snapshots than exist.
    """
    meta_suffix = '.meta'
    snap_steps = sorted(int(f[:-len(meta_suffix)].split('-')[-1])
                        for f in os.listdir(snap_dir) if f.endswith(meta_suffix))
    if step_ind < 0:
        if len(snap_steps) < -step_ind:
            raise ValueError('Cannot select snapshot index {:d}: only {:d} snapshot(s) found in {:s}'
                             .format(step_ind, len(snap_steps), snap_dir))
        chosen_step = snap_steps[step_ind]
    else:
        # Explicit step: the original script maps it to step + 1 (presumably
        # the trainer saves with a 1-based global step — TODO confirm).
        chosen_step = step_ind + 1
    return os.path.join(snap_dir, 'snap-{:d}'.format(chosen_step))


def test_caller(path, step_ind, kitti_dataset_path, shapenet_dataset_path, double_fold=None, gpu_id='0'):
    """Restore a trained completion model and run the KITTI completion test.

    Loads the configuration saved in ``path`` and builds ``KittiDataset``.
    It also builds a ``ShapeNetBenchmark2048Dataset``, used as the reference
    database for the MMD metric. It then creates
    ``KernelPointCompletionNetwork``, restores the chosen snapshot through
    ``ModelTester`` and calls ``tester.test_kitti_completion``.

    Args:
        path: Training log folder. It must be loadable by ``Config.load`` and
            contain a ``snapshots`` sub-folder.
        step_ind: Snapshot selector. A negative value indexes the sorted
            available steps (``-1`` = latest). A non-negative value selects
            ``snap-{step_ind + 1}``.
        kitti_dataset_path: Path forwarded to ``KittiDataset``.
        shapenet_dataset_path: Path forwarded to ``KittiDataset`` and
            ``ShapeNetBenchmark2048Dataset``.
        double_fold: Flag forwarded to ``KernelPointCompletionNetwork``.
            ``None`` (default) falls back to ``args.double_fold`` from a
            module-level ``args`` namespace if one exists, otherwise False.
        gpu_id: Value written to the ``CUDA_VISIBLE_DEVICES`` environment variable.

    Raises:
        ValueError: if a negative ``step_ind`` asks for a snapshot that does not exist.
    """
    ##########################
    # Initiate the environment
    ##########################

    # Restrict CUDA to the chosen GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id

    # '0' keeps every TensorFlow C++ log message (use '2' or '3' to hide warnings).
    # NOTE(review): this only takes effect if set before TensorFlow is imported.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

    ###########################
    # Load the model parameters
    ###########################

    config = Config()
    config.load(path)

    ##################################
    # Change model parameters for test
    ##################################

    # Change parameters for the test here. For example, you can stop augmenting the input data.
    # config.augment_noise = 0.0001
    # config.augment_color = 1.0
    # config.validation_size = 500
    # config.batch_num = 10

    ##############
    # Prepare Data
    ##############

    print()
    print('Dataset Preparation')
    print('*******************')

    # Initiate dataset configuration
    dataset = KittiDataset(config.batch_num, config.num_input_points, kitti_dataset_path, shapenet_dataset_path)

    dl0 = 0  # config.first_subsampling_dl

    # Create subsample clouds of the models
    dataset.load_subsampled_clouds(dl0)

    # Reference DB for the MMD metric.
    # TODO: for more efficiency, a car-only db could be used.
    shapenet2048_dataset = ShapeNetBenchmark2048Dataset(config.batch_num, config.num_input_points,
                                                        shapenet_dataset_path)
    shapenet2048_dataset.load_subsampled_clouds(dl0)

    # Initialize test input pipeline
    dataset.init_test_input_pipeline(config)

    ##############
    # Define Model
    ##############

    print('Creating Model')
    print('**************\n')
    t1 = time.time()

    if double_fold is None:
        # Historically the flag came straight from the global `args` built in the
        # __main__ block; keep that behaviour, but don't raise NameError on import.
        cli_args = globals().get('args')
        double_fold = getattr(cli_args, 'double_fold', False)

    model = KernelPointCompletionNetwork(dataset.flat_inputs, config, double_fold)

    # Find which snapshot to restore
    chosen_snap = _select_snapshot(os.path.join(path, 'snapshots'), step_ind)

    # Create a tester class
    tester = ModelTester(model, restore_snap=chosen_snap)
    t2 = time.time()

    print('\n----------------')
    print('Done in {:.1f} s'.format(t2 - t1))
    print('----------------\n')

    ############
    # Start test
    ############

    print('Start Test')
    print('**********\n')

    tester.test_kitti_completion(model, dataset, shapenet2048_dataset)
if __name__ == '__main__':
    # Command-line entry point. `args` deliberately stays a module-level name:
    # test_caller reads `args.double_fold` from it.
    parser = argparse.ArgumentParser(description='Test a trained model on KITTI raw data.')
    # required=True: a missing value would otherwise reach os.path.exists(None) and crash with TypeError.
    parser.add_argument('--saving_path', required=True, help="model_log_file_path")
    parser.add_argument('--snap', type=int, default=-1, help="snapshot to restore (-1 for latest snapshot)")
    parser.add_argument('--kitti_dataset_path', help="path forwarded to KittiDataset")
    parser.add_argument('--shapenet_dataset_path', help="path forwarded to the KITTI and ShapeNet datasets")
    parser.add_argument('--double_fold', action='store_true', help="flag forwarded to KernelPointCompletionNetwork")
    args = parser.parse_args()

    chosen_log = args.saving_path
    chosen_snapshot = args.snap

    # Check if log exists
    if not os.path.exists(chosen_log):
        raise ValueError('The given log does not exists: ' + chosen_log)

    test_caller(chosen_log, chosen_snapshot, args.kitti_dataset_path, args.shapenet_dataset_path)