# render.py
# (GitHub page chrome and copied line-number gutter removed from this
#  scraped copy; the module source begins below.)
import torch
import os
import json
from tqdm import tqdm
from lib.models.street_gaussian_model import StreetGaussianModel
from lib.models.street_gaussian_renderer import StreetGaussianRenderer
from lib.datasets.dataset import Dataset
from lib.models.scene import Scene
from lib.utils.general_utils import safe_state
from lib.config import cfg
from lib.visualizers.base_visualizer import BaseVisualizer as Visualizer
from lib.visualizers.street_gaussian_visualizer import StreetGaussianVisualizer
import time
def render_sets():
    """Render the train and/or test camera sets to images and report timing.

    Forces image saving (no video), renders every camera in the splits that
    are not skipped via ``cfg.eval.skip_train`` / ``cfg.eval.skip_test``, and
    prints the per-frame render times plus their average in milliseconds.

    Side effects: writes rendered images under
    ``<cfg.model_path>/{train,test}/ours_<iter>/`` and prints timing stats.
    """
    cfg.render.save_image = True
    cfg.render.save_video = False

    with torch.no_grad():
        dataset = Dataset()
        gaussians = StreetGaussianModel(dataset.scene_info.metadata)
        scene = Scene(gaussians=gaussians, dataset=dataset)
        renderer = StreetGaussianRenderer()

        times = []

        if not cfg.eval.skip_train:
            save_dir = os.path.join(cfg.model_path, 'train', "ours_{}".format(scene.loaded_iter))
            visualizer = Visualizer(save_dir)
            cameras = scene.getTrainCameras()
            times += _render_and_time(renderer, gaussians, visualizer, cameras,
                                      "Rendering Training View")

        if not cfg.eval.skip_test:
            save_dir = os.path.join(cfg.model_path, 'test', "ours_{}".format(scene.loaded_iter))
            visualizer = Visualizer(save_dir)
            cameras = scene.getTestCameras()
            times += _render_and_time(renderer, gaussians, visualizer, cameras,
                                      "Rendering Testing View")

        print(times)
        # The first measurement is discarded as warm-up (CUDA context/kernel
        # compilation). Guard against <=1 samples (e.g. both splits skipped,
        # or a single camera) which previously raised ZeroDivisionError.
        if len(times) > 1:
            print('average rendering time: ', sum(times[1:]) / len(times[1:]))
        else:
            print('not enough timing samples to compute an average')


def _render_and_time(renderer, gaussians, visualizer, cameras, desc):
    """Render each camera, visualize the result, return per-frame times (ms).

    CUDA is synchronized around each render so wall-clock time reflects the
    actual GPU work rather than asynchronous kernel launches.
    """
    times = []
    for camera in tqdm(cameras, desc=desc):
        torch.cuda.synchronize()
        start_time = time.time()
        result = renderer.render(camera, gaussians)
        torch.cuda.synchronize()
        times.append((time.time() - start_time) * 1000)
        visualizer.visualize(result, camera)
    return times
def render_trajectory():
    """Render every camera (train + test, ordered by id) into a video.

    Forces video output (no per-frame image saving), renders the merged
    camera sequence with the full renderer, and finalizes the video via
    the visualizer's ``summarize``.

    Side effects: writes the trajectory video under
    ``<cfg.model_path>/trajectory/ours_<iter>/``.
    """
    cfg.render.save_image = False
    cfg.render.save_video = True

    with torch.no_grad():
        dataset = Dataset()
        gaussians = StreetGaussianModel(dataset.scene_info.metadata)
        scene = Scene(gaussians=gaussians, dataset=dataset)
        renderer = StreetGaussianRenderer()

        save_dir = os.path.join(cfg.model_path, 'trajectory', "ours_{}".format(scene.loaded_iter))
        visualizer = StreetGaussianVisualizer(save_dir)

        # Merge both splits and play them back in camera-id order so the
        # video follows the original capture trajectory.
        all_cameras = sorted(scene.getTrainCameras() + scene.getTestCameras(),
                             key=lambda cam: cam.id)
        for camera in tqdm(all_cameras, desc="Rendering Trajectory"):
            frame = renderer.render_all(camera, gaussians)
            visualizer.visualize(frame, camera)
        visualizer.summarize()
if __name__ == "__main__":
    print("Rendering " + cfg.model_path)
    safe_state(cfg.eval.quiet)

    # Dispatch on the configured run mode; unknown modes fail loudly.
    _MODES = {
        'evaluate': render_sets,
        'trajectory': render_trajectory,
    }
    _handler = _MODES.get(cfg.mode)
    if _handler is None:
        raise NotImplementedError()
    _handler()