# test_raw.py — MonoCon tester for the KITTI Raw dataset.
# Forked from minghanz/monocon_na565 (65 lines, 47 loc, 1.87 KB).
import os
import sys
import torch
import argparse
from tqdm.auto import tqdm
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from utils.visualizer import Visualizer
from model.detector import MonoConDetector
from utils.engine_utils import tprint, move_data_device
from dataset.kitti_raw_dataset import KITTIRawDataset
# Arguments
# The four path arguments have no sensible default, so they are marked
# required: otherwise an omitted flag would be passed on as ``None`` to the
# dataset, the checkpoint loader, or the video exporter and fail later with
# an unrelated error. With ``required=True`` argparse reports it up front.
parser = argparse.ArgumentParser('MonoCon Tester for KITTI Raw Dataset')
parser.add_argument('--data_dir',
                    type=str,
                    required=True,
                    help="Path where sequence images are saved")
parser.add_argument('--calib_file',
                    type=str,
                    required=True,
                    help="Path to calibration file (.txt)")
parser.add_argument('--checkpoint_file',
                    type=str,
                    required=True,
                    help="Path of the checkpoint file (.pth)")
parser.add_argument('--gpu_id', type=int, default=0, help="Index of GPU to use for testing")
parser.add_argument('--fps', type=int, default=25, help="FPS of the result video")
parser.add_argument('--save_dir',
                    type=str,
                    required=True,
                    help="Path of the directory to save the inferenced video")
args = parser.parse_args()
# Main
# Pipeline: build the KITTI Raw dataset, load the MonoCon checkpoint onto the
# selected GPU, run inference over every sample without autograd, then render
# the collected results into a video under ``args.save_dir``.

# (1) Build Dataset
dataset = KITTIRawDataset(args.data_dir, args.calib_file)

# (2) Build Model
# Always targets a CUDA device; ``--gpu_id`` selects which one.
device = f'cuda:{args.gpu_id}'
detector = MonoConDetector()
detector.load_checkpoint(args.checkpoint_file)
detector.to(device)
detector.eval()
tprint(f"Checkpoint '{args.checkpoint_file}' is loaded to model.")

# (3) Inference
# Results are accumulated in dataset order; ``batch_eval`` is expected to
# return an iterable of per-sample visualization records (hence ``extend``)
# — NOTE(review): confirm against MonoConDetector.batch_eval.
vis_results = []
with torch.no_grad():
    for data in tqdm(dataset, desc="Collecting Results..."):
        data = move_data_device(data, device)
        vis_result = detector.batch_eval(data, get_vis_format=True)
        vis_results.extend(vis_result)

# (4) Visualize
# ``--save_dir`` is documented as a directory; make sure it exists before
# the exporter tries to write into it (no-op if it is already there).
if args.save_dir:
    os.makedirs(args.save_dir, exist_ok=True)
visualizer = Visualizer(dataset, vis_format=vis_results)
visualizer.export_as_video(args.save_dir, plot_items=['2d', '3d', 'bev'], fps=args.fps)