Skip to content

Commit

Permalink
Update 3DGS examples to use the latest nerfview from pypi (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
hangg7 authored Jun 8, 2024
1 parent c7b0a38 commit 07d9188
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 36 deletions.
4 changes: 2 additions & 2 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

# pycolmap for data parsing
git+https://github.com/rmbrualla/pycolmap@cc7ea4b7301720ac29287dbe450952511b32125e
# nerfview for viewer
git+https://github.com/hangg7/nerfview@4dde5291debd21ba33d768d9a8193aca87fc38fd
# (optional) nerfacc for torch version rasterization
# git+https://github.com/nerfstudio-project/nerfacc

viser
nerfview==0.0.2
imageio[ffmpeg]
numpy
scikit-learn
Expand Down
37 changes: 18 additions & 19 deletions examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
import torch.nn.functional as F
import tqdm
import tyro
import viser
import nerfview
from datasets.colmap import Dataset, Parser
from datasets.traj import generate_interpolated_path
from nerfview import VIEWER_LOCK, CameraState, ViewerServer
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
Expand Down Expand Up @@ -305,7 +306,12 @@ def __init__(self, cfg: Config) -> None:

# Viewer
if not self.cfg.disable_viewer:
self.server = ViewerServer(port=cfg.port, render_fn=self._viewer_render_fn)
self.server = viser.ViserServer(port=cfg.port, verbose=False)
self.viewer = nerfview.Viewer(
server=self.server,
render_fn=self._viewer_render_fn,
mode="training",
)

# Running stats for prunning & growing.
n_gauss = len(self.splats["means3d"])
Expand Down Expand Up @@ -401,9 +407,9 @@ def train(self):
pbar = tqdm.tqdm(range(init_step, max_steps))
for step in pbar:
if not cfg.disable_viewer:
while self.server.state.status == "paused":
while self.viewer.state.status == "paused":
time.sleep(0.01)
VIEWER_LOCK.acquire()
self.viewer.lock.acquire()
tic = time.time()

try:
Expand Down Expand Up @@ -624,15 +630,15 @@ def train(self):
self.render_traj(step)

if not cfg.disable_viewer:
VIEWER_LOCK.release()
self.viewer.lock.release()
num_train_steps_per_sec = 1.0 / (time.time() - tic)
num_train_rays_per_sec = (
num_train_rays_per_step * num_train_steps_per_sec
)
# Update the viewer state.
self.server.state.num_train_rays_per_sec = num_train_rays_per_sec
self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
# Update the scene.
self.server.update(step, num_train_rays_per_step)
self.viewer.update(step, num_train_rays_per_step)

@torch.no_grad()
def update_running_stats(self, info: Dict):
Expand Down Expand Up @@ -909,20 +915,13 @@ def render_traj(self, step: int):
print(f"Video saved to {video_dir}/traj_{step}.mp4")

@torch.no_grad()
def _viewer_render_fn(self, camera_state: CameraState, img_wh: Tuple[int, int]):
def _viewer_render_fn(
self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
):
"""Callable function for the viewer."""
fov = camera_state.fov
c2w = camera_state.c2w
W, H = img_wh

focal_length = H / 2.0 / np.tan(fov / 2.0)
K = np.array(
[
[focal_length, 0.0, W / 2.0],
[0.0, focal_length, H / 2.0],
[0.0, 0.0, 1.0],
]
)
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
c2w = torch.from_numpy(c2w).float().to(self.device)
K = torch.from_numpy(K).float().to(self.device)

Expand Down
26 changes: 11 additions & 15 deletions examples/simple_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
from typing import Tuple

import imageio
import nerfview
import numpy as np
import torch
import torch.nn.functional as F
from nerfview import CameraState, ViewerServer

import viser
from gsplat._helper import load_test_data
from gsplat.rendering import rasterization

Expand Down Expand Up @@ -135,19 +135,10 @@

# register and open viewer
@torch.no_grad()
def viewer_render_fn(camera_state: CameraState, img_wh: Tuple[int, int]):
fov = camera_state.fov
c2w = camera_state.c2w
def viewer_render_fn(camera_state: nerfview.CameraState, img_wh: Tuple[int, int]):
width, height = img_wh

focal_length = height / 2.0 / np.tan(fov / 2.0)
K = np.array(
[
[focal_length, 0.0, width / 2.0],
[0.0, focal_length, height / 2.0],
[0.0, 0.0, 1.0],
]
)
c2w = camera_state.c2w
K = camera_state.get_K(img_wh)
c2w = torch.from_numpy(c2w).float().to(device)
K = torch.from_numpy(K).float().to(device)
viewmat = c2w.inverse()
Expand Down Expand Up @@ -184,6 +175,11 @@ def viewer_render_fn(camera_state: CameraState, img_wh: Tuple[int, int]):
return render_rgbs


server = ViewerServer(port=args.port, render_fn=viewer_render_fn, mode="rendering")
server = viser.ViserServer(port=args.port, verbose=False)
_ = nerfview.Viewer(
server=server,
render_fn=viewer_render_fn,
mode="rendering",
)
print("Viewer running... Ctrl+C to exit.")
time.sleep(100000)

0 comments on commit 07d9188

Please sign in to comment.