forked from vLAR-group/DM-NeRF
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_replica.py
53 lines (43 loc) · 2 KB
/
test_replica.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from datasets.loader_replica import *
from config import create_nerf, initial
from networks.tester import render_test
from networks.manipulator import manipulator_demo
def test():
    """Render the held-out test views with the restored coarse/fine models.

    Takes no arguments: everything is read from module-level names bound in
    this script's ``__main__`` block (``args``, ``iteration``,
    ``position_embedder``, ``view_embedder``, ``model_coarse``,
    ``model_fine``, ``poses``, ``hwk``, ``images``, ``instances``,
    ``ins_colors``), so it must be called after that setup has run.

    Side effects: switches both models to eval mode, sets
    ``args.is_train = False``, and creates the directory
    ``<basedir>/<expname>/<log_time>/render_test_<iteration:06d>``. That
    directory is passed to ``render_test`` as ``savedir``. The path of a
    ``matching_log.txt`` file inside it is passed as ``matched_file``.
    Rendering runs under ``torch.no_grad()``.
    """
    for net in (model_coarse, model_fine):
        net.eval()
    args.is_train = False

    with torch.no_grad():
        print('Rendering......')
        run_dir = os.path.join(args.basedir, args.expname, args.log_time)
        out_dir = os.path.join(run_dir, 'render_test_{:06d}'.format(iteration))
        os.makedirs(out_dir, exist_ok=True)
        match_log_path = os.path.join(out_dir, 'matching_log.txt')

        # To render only a subset of views, index poses/images/instances with
        # a list of view ids before passing them in.
        render_test(
            position_embedder,
            view_embedder,
            model_coarse,
            model_fine,
            poses,
            hwk,
            args,
            gt_imgs=images,
            gt_labels=instances,
            ins_rgbs=ins_colors,
            savedir=out_dir,
            matched_file=match_log_path,
        )
        print('Rendering Done', out_dir)
if __name__ == '__main__':
    args = initial()

    # Load the dataset. load_data also supplies the instance count, which is
    # stored on args.ins_num before the networks are built.
    images, poses, hwk, i_split, instances, ins_colors, args.ins_num = load_data(args)
    print('Load data from', args.datadir)
    H, W, K = hwk  # presumably height, width, intrinsics -- unpack also checks hwk has 3 items
    i_train, i_test = i_split

    position_embedder, view_embedder, model_coarse, model_fine, args = create_nerf(args)

    # Restore the trained weights from the run directory.
    ckpt_path = os.path.join(args.basedir, args.expname, args.log_time, args.test_model)
    print('Reloading from', ckpt_path)
    # Deserialize onto the CPU. load_state_dict copies every tensor into the
    # models' existing parameters (whatever device they live on), so the
    # checkpoint does not need to be materialized on the device it was saved
    # from. This avoids load failures on CPU-only / different-GPU hosts and a
    # redundant second on-device copy of the weights.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    iteration = ckpt['iteration']  # read as a global by test() to name the output dir
    model_coarse.load_state_dict(ckpt['network_coarse_state_dict'])
    model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    # Keep only the test split. torch.Tensor yields float32. The instance
    # labels are then cast to int16 (exact for any value in int16 range).
    images = torch.Tensor(images[i_test])
    instances = torch.Tensor(instances[i_test]).type(torch.int16)
    poses = torch.Tensor(poses[i_test])

    # Test-time rendering without perturbation; the flag is consumed outside
    # this file (NOTE(review): confirm in networks.tester).
    args.perturb = False
    test()