-
Notifications
You must be signed in to change notification settings - Fork 35
/
test.py
119 lines (101 loc) · 5.23 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import sys
import numpy as np
import argparse
import time
import torch
from torch.utils.data import DataLoader
from graspnetAPI.graspnet_eval import GraspGroup, GraspNetEval
# Directory holding this script. Its sub-folders are appended to sys.path (in this
# order) so that modules inside them are importable by bare name.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(os.path.join(ROOT_DIR, subdir) for subdir in ('pointnet2', 'utils', 'models', 'dataset'))
from models.graspnet import GraspNet, pred_decode
from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn
from utils.collision_detector import ModelFreeCollisionDetector
# Command-line configuration. Parsed once at import time into the module-level `cfgs`.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', required=True, help='GraspNet dataset root directory')
parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path')
parser.add_argument('--dump_dir', required=True, help='Dump dir to save outputs')
parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point wise feature dim')
# Only these two camera splits exist; reject anything else up front instead of
# failing later inside dataset loading / evaluation.
parser.add_argument('--camera', default='kinect', choices=['realsense', 'kinect'],
                    help='Camera split [realsense/kinect]')
parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: 15000]')
parser.add_argument('--batch_size', type=int, default=1, help='Batch Size during inference [default: 1]')
parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel Size for sparse convolution')
parser.add_argument('--collision_thresh', type=float, default=0.01,
                    help='Collision Threshold in collision detection [default: 0.01]')
parser.add_argument('--voxel_size_cd', type=float, default=0.01, help='Voxel Size for collision detection')
parser.add_argument('--infer', action='store_true', default=False, help='Run inference and dump predicted grasps')
parser.add_argument('--eval', action='store_true', default=False, help='Evaluate the dumped grasps')
cfgs = parser.parse_args()
# Ensure the output directory exists. `makedirs` also creates missing parent
# directories, and `exist_ok` avoids the check-then-create race of exists()+mkdir().
os.makedirs(cfgs.dump_dir, exist_ok=True)
def my_worker_init_fn(worker_id):
    """Reseed NumPy's global RNG with a per-worker seed.

    The seed is the first word of the current MT19937 key plus ``worker_id``.
    It is wrapped into ``[0, 2**32)`` because ``np.random.seed`` raises
    ``ValueError`` for larger values.

    Args:
        worker_id: integer id of the DataLoader worker calling this hook.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2 ** 32)
def _batch_to_device(batch_data, device):
    """Move every tensor of a collated batch onto ``device``, in place.

    Values under keys containing ``'list'`` are treated as a sequence of
    sequences of tensors. Every other value is moved directly with ``.to``.
    """
    for key in batch_data:
        if 'list' in key:
            for i in range(len(batch_data[key])):
                for j in range(len(batch_data[key][i])):
                    batch_data[key][i][j] = batch_data[key][i][j].to(device)
        else:
            batch_data[key] = batch_data[key].to(device)
    return batch_data


def _save_grasps(gg, data_idx, scene_list):
    """Save ``gg`` to ``<dump_dir>/<scene>/<camera>/<data_idx % 256 as 4 digits>.npy``."""
    save_dir = os.path.join(cfgs.dump_dir, scene_list[data_idx], cfgs.camera)
    # NOTE(review): assumes each scene contributes 256 consecutive samples, so the
    # remainder is the per-scene annotation index. Confirm against GraspNetDataset.
    save_path = os.path.join(save_dir, str(data_idx % 256).zfill(4) + '.npy')
    os.makedirs(save_dir, exist_ok=True)
    gg.save_npy(save_path)


def inference():
    """Run GraspNet on the ``test_seen`` split and dump the predicted grasps.

    Each sample's decoded grasps are filtered by model-free collision detection
    when ``cfgs.collision_thresh > 0``. They are then written under
    ``cfgs.dump_dir`` as ``<scene>/<camera>/<index>.npy``. The average time
    per batch is printed every 100 batches.
    """
    # Init datasets and dataloaders
    test_dataset = GraspNetDataset(cfgs.dataset_root, split='test_seen', camera=cfgs.camera,
                                   num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
                                   remove_outlier=True, augment=False, load_label=False)
    print('Test dataset length: ', len(test_dataset))
    scene_list = test_dataset.scene_list()
    test_dataloader = DataLoader(test_dataset, batch_size=cfgs.batch_size, shuffle=False,
                                 num_workers=0, worker_init_fn=my_worker_init_fn,
                                 collate_fn=minkowski_collate_fn)
    print('Test dataloader length: ', len(test_dataloader))

    # Init the model
    net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=False)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net.to(device)

    # Load checkpoint. map_location lets a checkpoint saved on GPU load on the
    # CPU fallback selected above.
    checkpoint = torch.load(cfgs.checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    start_epoch = checkpoint['epoch']
    print("-> loaded checkpoint %s (epoch: %d)" % (cfgs.checkpoint_path, start_epoch))

    batch_interval = 100
    net.eval()
    tic = time.time()
    for batch_idx, batch_data in enumerate(test_dataloader):
        _batch_to_device(batch_data, device)

        # Forward pass
        with torch.no_grad():
            end_points = net(batch_data)
            grasp_preds = pred_decode(end_points)

        # The last batch can be smaller than cfgs.batch_size (drop_last=False), so
        # loop over the predictions actually decoded. This assumes pred_decode
        # yields one entry per sample, as the per-sample indexing implies.
        for i in range(len(grasp_preds)):
            # Every earlier batch was full, so this is the sample's dataset index.
            data_idx = batch_idx * cfgs.batch_size + i
            preds = grasp_preds[i].detach().cpu().numpy()
            gg = GraspGroup(preds)

            # collision detection
            if cfgs.collision_thresh > 0:
                cloud = test_dataset.get_data(data_idx, return_raw_cloud=True)
                mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size_cd)
                collision_mask = mfcdetector.detect(gg, approach_dist=0.05,
                                                    collision_thresh=cfgs.collision_thresh)
                gg = gg[~collision_mask]

            _save_grasps(gg, data_idx, scene_list)

        if (batch_idx + 1) % batch_interval == 0:
            toc = time.time()
            print('Eval batch: %d, time: %fs' % (batch_idx + 1, (toc - tic) / batch_interval))
            tic = time.time()
def evaluate(dump_dir, proc=6):
    """Evaluate dumped grasps on the ``test_seen`` split with graspnetAPI.

    Args:
        dump_dir: folder holding the grasp dumps written by ``inference``.
            The ``res`` array from the evaluation is also saved here as
            ``ap_<camera>.npy``.
        proc: number of processes passed to ``GraspNetEval.eval_seen``.

    Returns:
        The ``(res, ap)`` pair returned by ``GraspNetEval.eval_seen``.
    """
    ge = GraspNetEval(root=cfgs.dataset_root, camera=cfgs.camera, split='test_seen')
    res, ap = ge.eval_seen(dump_folder=dump_dir, proc=proc)
    # Save next to the dumps that were evaluated, i.e. honour the argument rather
    # than the global cfgs.dump_dir.
    save_path = os.path.join(dump_dir, 'ap_{}.npy'.format(cfgs.camera))
    np.save(save_path, res)
    print('Evaluation result saved to {} (AP = {})'.format(save_path, ap))
    return res, ap
def _main():
    """Run inference and/or evaluation, as requested by --infer / --eval."""
    if cfgs.infer:
        inference()
    if cfgs.eval:
        evaluate(cfgs.dump_dir)


if __name__ == '__main__':
    _main()