From d31aa2b3a2e3f75d89cf9c719c378e7aae79215e Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 2 Jul 2024 14:35:01 -0700 Subject: [PATCH 01/66] use inria cuda to train --- examples/simple_trainer_mcmc.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 8c6a21924..5775acd87 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -25,9 +25,8 @@ set_random_seed, ) from gsplat import quat_scale_to_covar_preci -from gsplat.rendering import rasterization +from gsplat.rendering import rasterization, rasterization_inria_wrapper from gsplat.relocation import compute_relocation -from gsplat.cuda_legacy._torch_impl import scale_rot_to_cov3d from simple_trainer import create_splats_with_optimizers @@ -62,9 +61,9 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + save_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -273,9 +272,9 @@ def rasterize_splats( **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means3d"] # [N, 3] - # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] + quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] # rasterization does normalization internally - quats = self.splats["quats"] # [N, 4] + # quats = self.splats["quats"] # [N, 4] scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] @@ -293,7 +292,7 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" - render_colors, render_alphas, info = rasterization( + render_colors, render_alphas, info = rasterization_inria_wrapper( means=means, quats=quats, scales=scales, @@ -405,8 +404,6 @@ def train(self): bkgd = torch.rand(1, 3, device=device) colors = colors + bkgd * (1.0 - alphas) - info["means2d"].retain_grad() # used for running stats - # loss l1loss = F.l1_loss(colors, pixels) ssimloss = 1.0 - self.ssim( @@ -696,9 +693,8 @@ def eval(self, step: int): # write images canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() - imageio.imwrite( - f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) - ) + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite(f"{self.render_dir}/val_step{step:04d}_{i:04d}.png", canvas) pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] @@ -764,15 +760,13 @@ def render_traj(self, step: int): far_plane=cfg.far_plane, render_mode="RGB+ED", ) # [1, H, W, 4] - colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] - depths = renders[0, ..., 3:4] # [H, W, 1] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) # write images - canvas = torch.cat( - [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1 - ) - canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas = torch.cat([colors, depths.repeat(1, 1, 1, 3)], dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) canvas_all.append(canvas) # save to video From 6a58dc3791e332abaac718b76047edbb8740f387 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 2 Jul 2024 14:54:58 -0700 Subject: [PATCH 02/66] canvas list --- examples/benchmark_mcmc.sh | 63 +++++++++++++++++++++++++++++++++ examples/simple_trainer_mcmc.py | 12 ++++--- 2 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 examples/benchmark_mcmc.sh diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh new file mode 100644 index 000000000..da126dc17 --- /dev/null +++ b/examples/benchmark_mcmc.sh @@ -0,0 +1,63 @@ +RESULT_DIR=results/mcmc_sfm_inria + +# for SCENE in bicycle bonsai counter garden kitchen room stump; +for SCENE in bonsai counter kitchen room bicycle garden stump; +do + if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then + DATA_FACTOR=4 + else + DATA_FACTOR=2 + fi + + if [ "$SCENE" = "bonsai" ]; then + CAP_MAX=1300000 + elif [ "$SCENE" = "counter" ]; then + CAP_MAX=1200000 + elif [ "$SCENE" = "kitchen" ]; then + CAP_MAX=1800000 + elif [ "$SCENE" = "room" ]; then + CAP_MAX=1500000 + else + CAP_MAX=3000000 + fi + + echo "Running $SCENE" + + # train without eval + python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + + # # run eval and render + # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + # do + # python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \ + # --data_dir data/360_v2/$SCENE/ \ + # --result_dir $RESULT_DIR/$SCENE/ \ + # --ckpt $CKPT + # done +done + + +for SCENE in bicycle bonsai counter garden kitchen room stump; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val*; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train*; + do + echo $STATS + cat $STATS; + echo + done +done \ No newline at end of file diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 5775acd87..b22a1ccd4 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -692,7 +692,8 @@ def eval(self, step: int): ellipse_time += time.time() - tic # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + canvas_list = [pixels, colors] + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() canvas = (canvas * 255).astype(np.uint8) imageio.imwrite(f"{self.render_dir}/val_step{step:04d}_{i:04d}.png", canvas) @@ -761,11 +762,14 @@ def render_traj(self, step: int): render_mode="RGB+ED", ) # [1, H, W, 4] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] - depths = renders[..., 3:4] # [1, H, W, 1] - depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors] + if renders.shape[-1] == 4: + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(depths.repeat(1, 1, 1, 3)) # write images - canvas = torch.cat([colors, depths.repeat(1, 1, 1, 3)], dim=2).squeeze(0).cpu().numpy() + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() canvas = (canvas * 255).astype(np.uint8) canvas_all.append(canvas) From 1b6e094687ad8ed6b6a7103d6008be53991787b6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 2 Jul 2024 18:20:28 -0700 Subject: [PATCH 03/66] 2dgs --- examples/benchmark_mcmc.sh | 4 +-- examples/simple_trainer_mcmc.py | 35 ++++++++++++++++--- gsplat/point_utils.py | 30 ++++++++++++++++ gsplat/rendering.py | 62 +++++++++++++++++++++++++++++---- 4 files changed, 119 insertions(+), 12 deletions(-) create mode 100644 gsplat/point_utils.py diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index da126dc17..83368ac47 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,7 +1,7 @@ -RESULT_DIR=results/mcmc_sfm_inria +RESULT_DIR=results/mcmc_sfm_inria_2dgs # for SCENE in bicycle bonsai counter garden kitchen room stump; -for SCENE in bonsai counter kitchen room bicycle garden stump; +for SCENE in garden treehill bonsai counter kitchen room bicycle stump; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then DATA_FACTOR=4 diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index b22a1ccd4..75bf6082c 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -81,6 +81,8 @@ class Config: init_scale: float = 0.1 # Weight for SSIM loss ssim_lambda: float = 0.2 + lambda_normal: float = 0.05 + lambda_dist: float = 0.0 # Near plane clipping distance near_plane: float = 0.01 @@ -306,6 +308,7 @@ def rasterize_splats( absgrad=self.cfg.absgrad, sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, + mode="2dgs", **kwargs, ) return render_colors, render_alphas, info @@ -410,6 +413,19 @@ def train(self): pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + + # regularization + lambda_normal = cfg.lambda_normal if step > 7000 else 0.0 + lambda_dist = cfg.lambda_dist if step > 3000 else 0.0 + + rend_dist = info["rend_dist"] + rend_normal = info['rend_normal'] + surf_normal = info['surf_normal'] + normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None] + normal_loss = (normal_error).mean() + dist_loss = (rend_dist).mean() + loss += lambda_normal * normal_loss + lambda_dist * dist_loss + if cfg.depth_loss: # query depths from depth map points = torch.stack( @@ -456,6 +472,8 @@ def train(self): self.writer.add_scalar("train/loss", loss.item(), step) self.writer.add_scalar("train/l1loss", l1loss.item(), step) self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/normal_loss", normal_loss.item(), step) + self.writer.add_scalar("train/dist_loss", dist_loss.item(), step) self.writer.add_scalar( "train/num_GS", len(self.splats["means3d"]), step ) @@ -678,7 +696,7 @@ def eval(self, step: int): torch.cuda.synchronize() tic = time.time() - colors, _, _ = self.rasterize_splats( + colors, _, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -688,11 +706,15 @@ def eval(self, step: int): far_plane=cfg.far_plane, ) # [1, H, W, 3] colors = torch.clamp(colors, 0.0, 1.0) + rend_normals = info["rend_normal"].permute(1,2,0).unsqueeze(0) + surf_normals = info["surf_normal"].permute(1,2,0).unsqueeze(0) + rend_normals = rend_normals * -0.5 + 0.5 + surf_normals = surf_normals * -0.5 + 0.5 torch.cuda.synchronize() ellipse_time += time.time() - tic # write images - canvas_list = [pixels, colors] + canvas_list = [pixels, colors, rend_normals, surf_normals] canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() canvas = (canvas * 255).astype(np.uint8) imageio.imwrite(f"{self.render_dir}/val_step{step:04d}_{i:04d}.png", canvas) @@ -751,7 +773,7 @@ def render_traj(self, step: int): canvas_all = [] for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): - renders, _, _ = self.rasterize_splats( + renders, _, info = self.rasterize_splats( camtoworlds=camtoworlds[i : i + 1], Ks=K[None], width=width, @@ -762,7 +784,12 @@ def render_traj(self, step: int): render_mode="RGB+ED", ) # [1, H, W, 4] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] - canvas_list = [colors] + rend_normals = info["rend_normal"].permute(1,2,0).unsqueeze(0) + surf_normals = info["surf_normal"].permute(1,2,0).unsqueeze(0) + rend_normals = rend_normals * -0.5 + 0.5 + surf_normals = surf_normals * -0.5 + 0.5 + + canvas_list = [colors, rend_normals, surf_normals] if renders.shape[-1] == 4: depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) diff --git a/gsplat/point_utils.py b/gsplat/point_utils.py new file mode 100644 index 000000000..6ddbcc2c3 --- /dev/null +++ b/gsplat/point_utils.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +def depths_to_points(depthmap, world_view_transform, full_proj_transform): + c2w = (world_view_transform.T).inverse() + H, W = depthmap.shape[-2:] + ndc2pix = torch.tensor([ + [W / 2, 0, 0, (W) / 2], + [0, H / 2, 0, (H) / 2], + [0, 0, 0, 1]]).float().cuda().T + projection_matrix = c2w.T @ full_proj_transform + intrins = (projection_matrix @ ndc2pix)[:3,:3].T + + grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') + points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) + rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T + rays_o = c2w[:3,3] + points = depthmap.reshape(-1, 1) * rays_d + rays_o + return points + +def depth_to_normal(depth, world_view_transform, full_proj_transform): + points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape(*depth.shape[1:], 3) + output = torch.zeros_like(points) + dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) + dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) + normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + output[1:-1, 1:-1, :] = normal_map + return output diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 7339e5b2f..3e144afab 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -12,6 +12,7 @@ rasterize_to_pixels, spherical_harmonics, ) +from .point_utils import depth_to_normal def rasterization( @@ -507,6 +508,7 @@ def rasterization_inria_wrapper( eps2d: float = 0.3, sh_degree: Optional[int] = None, backgrounds: Optional[Tensor] = None, + mode: str = "3dgs", **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: """Wrapper for Inria's rasterization backend. @@ -521,10 +523,17 @@ def rasterization_inria_wrapper( https://github.com/graphdeco-inria/diff-gaussian-rasterization """ - from diff_gaussian_rasterization import ( - GaussianRasterizationSettings, - GaussianRasterizer, - ) + if mode == "3dgs": + from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) + elif mode == "2dgs": + from diff_surfel_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) + scales = scales[:, :2] def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): tanHalfFovY = math.tan((fovY / 2)) @@ -602,7 +611,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): _colors.shape[0], 3 - _colors.shape[-1], device=device ) _colors = torch.cat([_colors, pad], dim=-1) - _render_colors_, radii = rasterizer( + _render_colors_, _, allmap = rasterizer( means3D=means, means2D=means2D, shs=_colors if colors.dim() == 3 else None, @@ -621,4 +630,45 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) - return render_colors, None, {} + + # additional regularizations + render_alpha = allmap[1:2] + + # get normal map + # transform normal from view space to world space + render_normal = allmap[2:5] + render_normal = (render_normal.permute(1,2,0) @ (world_view_transform[:3,:3].T)).permute(2,0,1) + + # get median depth map + render_depth_median = allmap[5:6] + render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) + + # get expected depth map + render_depth_expected = allmap[0:1] + render_depth_expected = (render_depth_expected / render_alpha) + render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) + + # get depth distortion map + render_dist = allmap[6:7] + + # psedo surface attributes + # surf depth is either median or expected by setting depth_ratio to 1 or 0 + # for bounded scene, use median depth, i.e., depth_ratio = 1; + # for unbounded scene, use expected depth, i.e., depth_ration = 0, to reduce disk anliasing. + depth_ratio = 0 + surf_depth = render_depth_expected * (1-depth_ratio) + (depth_ratio) * render_depth_median + + # assume the depth points form the 'surface' and generate psudo surface normal for regularizations. + surf_normal = depth_to_normal(surf_depth, world_view_transform, full_proj_transform) + surf_normal = surf_normal.permute(2,0,1) + # remember to multiply with accum_alpha since render_normal is unnormalized. + surf_normal = surf_normal * (render_alpha).detach() + + info = { + 'rend_alpha': render_alpha, + 'rend_normal': render_normal, + 'rend_dist': render_dist, + 'surf_depth': surf_depth, + 'surf_normal': surf_normal, + } + return render_colors, None, info From 55b39bbef4145dd0ee6c31517ebaae040d6f0f8f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 2 Jul 2024 20:50:02 -0700 Subject: [PATCH 04/66] lambda_dist --- examples/benchmark_mcmc.sh | 6 +++--- examples/simple_trainer_mcmc.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index 83368ac47..e4d23fa0c 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,9 +1,9 @@ RESULT_DIR=results/mcmc_sfm_inria_2dgs # for SCENE in bicycle bonsai counter garden kitchen room stump; -for SCENE in garden treehill bonsai counter kitchen room bicycle stump; +for SCENE in treehill garden flowers bonsai counter kitchen room bicycle stump; do - if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ]; then + if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then DATA_FACTOR=4 else DATA_FACTOR=2 @@ -18,7 +18,7 @@ do elif [ "$SCENE" = "room" ]; then CAP_MAX=1500000 else - CAP_MAX=3000000 + CAP_MAX=2000000 fi echo "Running $SCENE" diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 75bf6082c..39ae39419 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -82,7 +82,7 @@ class Config: # Weight for SSIM loss ssim_lambda: float = 0.2 lambda_normal: float = 0.05 - lambda_dist: float = 0.0 + lambda_dist: float = 100 # 1000 for bounded scenes, 100 for unbounded scenes # Near plane clipping distance near_plane: float = 0.01 From f4ea0e0d77c2ba3f215689ab09cdd1920bdb67b4 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 3 Jul 2024 11:49:18 -0700 Subject: [PATCH 05/66] clean up --- examples/benchmark_mcmc.sh | 6 +- examples/simple_trainer_mcmc.py | 87 ++++++++------- gsplat/point_utils.py | 4 +- gsplat/rendering.py | 182 +++++++++++++++++++++++++++----- 4 files changed, 207 insertions(+), 72 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index e4d23fa0c..fd70eea09 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,12 +1,14 @@ -RESULT_DIR=results/mcmc_sfm_inria_2dgs +RESULT_DIR=results/mcmc_sfm_2dgs # for SCENE in bicycle bonsai counter garden kitchen room stump; for SCENE in treehill garden flowers bonsai counter kitchen room bicycle stump; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then DATA_FACTOR=4 + DIST_LAMBDA=100 else DATA_FACTOR=2 + DIST_LAMBDA=1000 fi if [ "$SCENE" = "bonsai" ]; then @@ -25,6 +27,8 @@ do # train without eval python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 2dgs \ + --dist_lambda $DIST_LAMBDA \ --init_type sfm \ --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 39ae39419..4a585bb44 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -25,13 +25,15 @@ set_random_seed, ) from gsplat import quat_scale_to_covar_preci -from gsplat.rendering import rasterization, rasterization_inria_wrapper +from gsplat.rendering import rasterization, rasterization_2dgs_inria_wrapper from gsplat.relocation import compute_relocation from simple_trainer import create_splats_with_optimizers @dataclass class Config: + # Model type can be 3dgs or 2dgs + model_type: str = "3dgs" # Disable viewer disable_viewer: bool = False # Path to the .pt file. If provide, it will skip training and render a video @@ -81,8 +83,6 @@ class Config: init_scale: float = 0.1 # Weight for SSIM loss ssim_lambda: float = 0.2 - lambda_normal: float = 0.05 - lambda_dist: float = 100 # 1000 for bounded scenes, 100 for unbounded scenes # Near plane clipping distance near_plane: float = 0.01 @@ -139,6 +139,13 @@ class Config: depth_loss: bool = False # Weight for depth loss depth_lambda: float = 1e-2 + + # Enable normal loss. (experimental) + normal_loss: bool = True + # Weight for normal loss + normal_lambda: float = 0.05 + # Weight for distortion loss. Use 100 for unbounded scenes and 1000 for bounded scenes + dist_lambda: float = 100 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -163,6 +170,12 @@ def __init__(self, cfg: Config) -> None: self.cfg = cfg self.device = "cuda" + if cfg.model_type == "3dgs": + self.rasterization_fn = rasterization + elif cfg.model_type == "2dgs": + self.rasterization_fn = rasterization_2dgs_inria_wrapper + else: + raise ValueError(f"Unsupported model type: {cfg.model_type}") # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -274,9 +287,8 @@ def rasterize_splats( **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means3d"] # [N, 3] - quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4] # rasterization does normalization internally - # quats = self.splats["quats"] # [N, 4] + quats = self.splats["quats"] # [N, 4] scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] @@ -294,7 +306,8 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" - render_colors, render_alphas, info = rasterization_inria_wrapper( + + render_colors, render_alphas, info = self.rasterization_fn( means=means, quats=quats, scales=scales, @@ -413,18 +426,8 @@ def train(self): pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda - - # regularization - lambda_normal = cfg.lambda_normal if step > 7000 else 0.0 - lambda_dist = cfg.lambda_dist if step > 3000 else 0.0 - - rend_dist = info["rend_dist"] - rend_normal = info['rend_normal'] - surf_normal = info['surf_normal'] - normal_error = (1 - (rend_normal * surf_normal).sum(dim=0))[None] - normal_loss = (normal_error).mean() - dist_loss = (rend_dist).mean() - loss += lambda_normal * normal_loss + lambda_dist * dist_loss + loss += cfg.opacity_reg * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + loss += cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() if cfg.depth_loss: # query depths from depth map @@ -445,16 +448,16 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda + + if cfg.normal_loss: + normal_lambda = cfg.normal_lambda if step > 7000 else 0.0 + dist_lambda = cfg.dist_lambda if step > 3000 else 0.0 - loss = ( - loss - + cfg.opacity_reg - * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() - ) - loss = ( - loss - + cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() - ) + rend_normal = info['rend_normal'] + surf_normal = info['surf_normal'] + normalloss = (1 - (rend_normal * surf_normal).sum(dim=-1)).mean() + distloss = (info["rend_dist"]).mean() + loss += normal_lambda * normalloss + dist_lambda * distloss loss.backward() @@ -472,14 +475,15 @@ def train(self): self.writer.add_scalar("train/loss", loss.item(), step) self.writer.add_scalar("train/l1loss", l1loss.item(), step) self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) - self.writer.add_scalar("train/normal_loss", normal_loss.item(), step) - self.writer.add_scalar("train/dist_loss", dist_loss.item(), step) self.writer.add_scalar( "train/num_GS", len(self.splats["means3d"]), step ) self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.normal_loss: + self.writer.add_scalar("train/normalloss", normalloss.item(), step) + self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) @@ -705,16 +709,17 @@ def eval(self, step: int): near_plane=cfg.near_plane, far_plane=cfg.far_plane, ) # [1, H, W, 3] - colors = torch.clamp(colors, 0.0, 1.0) - rend_normals = info["rend_normal"].permute(1,2,0).unsqueeze(0) - surf_normals = info["surf_normal"].permute(1,2,0).unsqueeze(0) - rend_normals = rend_normals * -0.5 + 0.5 - surf_normals = surf_normals * -0.5 + 0.5 torch.cuda.synchronize() ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + if self.cfg.normal_loss: + rend_normals = info["rend_normal"] * 0.5 + 0.5 + surf_normals = info["surf_normal"] * 0.5 + 0.5 + canvas_list.extend([rend_normals, surf_normals]) # write images - canvas_list = [pixels, colors, rend_normals, surf_normals] canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() canvas = (canvas * 255).astype(np.uint8) imageio.imwrite(f"{self.render_dir}/val_step{step:04d}_{i:04d}.png", canvas) @@ -783,17 +788,17 @@ def render_traj(self, step: int): far_plane=cfg.far_plane, render_mode="RGB+ED", ) # [1, H, W, 4] - colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] - rend_normals = info["rend_normal"].permute(1,2,0).unsqueeze(0) - surf_normals = info["surf_normal"].permute(1,2,0).unsqueeze(0) - rend_normals = rend_normals * -0.5 + 0.5 - surf_normals = surf_normals * -0.5 + 0.5 - canvas_list = [colors, rend_normals, surf_normals] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + canvas_list = [colors] if renders.shape[-1] == 4: depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths.repeat(1, 1, 1, 3)) + if self.cfg.normal_loss: + rend_normals = info["rend_normal"] * 0.5 + 0.5 + surf_normals = info["surf_normal"] * 0.5 + 0.5 + canvas_list.extend([rend_normals, surf_normals]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/point_utils.py b/gsplat/point_utils.py index 6ddbcc2c3..6e5926a9d 100644 --- a/gsplat/point_utils.py +++ b/gsplat/point_utils.py @@ -5,7 +5,7 @@ def depths_to_points(depthmap, world_view_transform, full_proj_transform): c2w = (world_view_transform.T).inverse() - H, W = depthmap.shape[-2:] + H, W = depthmap.shape[:2] ndc2pix = torch.tensor([ [W / 2, 0, 0, (W) / 2], [0, H / 2, 0, (H) / 2], @@ -21,7 +21,7 @@ def depths_to_points(depthmap, world_view_transform, full_proj_transform): return points def depth_to_normal(depth, world_view_transform, full_proj_transform): - points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape(*depth.shape[1:], 3) + points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape(*depth.shape[:2], 3) output = torch.zeros_like(points) dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 3e144afab..2b4980d8b 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -3,6 +3,7 @@ import torch from torch import Tensor +import torch.nn.functional as F from typing_extensions import Literal from .cuda._wrapper import ( @@ -508,7 +509,6 @@ def rasterization_inria_wrapper( eps2d: float = 0.3, sh_degree: Optional[int] = None, backgrounds: Optional[Tensor] = None, - mode: str = "3dgs", **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: """Wrapper for Inria's rasterization backend. @@ -523,17 +523,141 @@ def rasterization_inria_wrapper( https://github.com/graphdeco-inria/diff-gaussian-rasterization """ - if mode == "3dgs": - from diff_gaussian_rasterization import ( - GaussianRasterizationSettings, - GaussianRasterizer, + from diff_gaussian_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) + + def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4, device=device) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P + + assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" + C = len(viewmats) + device = means.device + channels = colors.shape[-1] + + render_colors = [] + for cid in range(C): + FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) + FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) + tanfovx = math.tan(FoVx * 0.5) + tanfovy = math.tan(FoVy * 0.5) + + world_view_transform = viewmats[cid].transpose(0, 1) + projection_matrix = _getProjectionMatrix( + znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device + ).transpose(0, 1) + full_proj_transform = ( + world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) + ).squeeze(0) + camera_center = world_view_transform.inverse()[3, :3] + + background = ( + backgrounds[cid] + if backgrounds is not None + else torch.zeros(3, device=device) ) - elif mode == "2dgs": - from diff_surfel_rasterization import ( - GaussianRasterizationSettings, - GaussianRasterizer, + + raster_settings = GaussianRasterizationSettings( + image_height=height, + image_width=width, + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=background, + scale_modifier=1.0, + viewmatrix=world_view_transform, + projmatrix=full_proj_transform, + sh_degree=0 if sh_degree is None else sh_degree, + campos=camera_center, + prefiltered=False, + debug=False, ) - scales = scales[:, :2] + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means2D = torch.zeros_like(means, requires_grad=True, device=device) + + render_colors_ = [] + for i in range(0, channels, 3): + _colors = colors[..., i : i + 3] + if _colors.shape[-1] < 3: + pad = torch.zeros( + _colors.shape[0], 3 - _colors.shape[-1], device=device + ) + _colors = torch.cat([_colors, pad], dim=-1) + _render_colors_, _ = rasterizer( + means3D=means, + means2D=means2D, + shs=_colors if colors.dim() == 3 else None, + colors_precomp=_colors if colors.dim() == 2 else None, + opacities=opacities[:, None], + scales=scales, + rotations=quats, + cov3D_precomp=None, + ) + if _colors.shape[-1] < 3: + _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]] + render_colors_.append(_render_colors_) + render_colors_ = torch.cat(render_colors_, dim=-1) + + render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] + + render_colors.append(render_colors_) + render_colors = torch.stack(render_colors, dim=0) + return render_colors, None, {} + +def rasterization_2dgs_inria_wrapper( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + opacities: Tensor, # [N] + colors: Tensor, # [N, D] or [N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 100.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + backgrounds: Optional[Tensor] = None, + **kwargs, +) -> Tuple[Tensor, Tensor, Dict]: + """Wrapper for Inria's rasterization backend. + + .. warning:: + This function exists for comparision purpose only. Only rendered image is + returned. + + .. warning:: + Inria's CUDA backend has its own LICENSE, so this function should be used with + the respect to the original LICENSE at: + https://github.com/graphdeco-inria/diff-gaussian-rasterization + + """ + from diff_surfel_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): tanHalfFovY = math.tan((fovY / 2)) @@ -561,6 +685,9 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): C = len(viewmats) device = means.device channels = colors.shape[-1] + + scales = scales[:, :2] # [N, 2] + quats = quats = F.normalize(quats, dim=-1) # [N, 4] render_colors = [] for cid in range(C): @@ -627,29 +754,29 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors_ = torch.cat(render_colors_, dim=-1) render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] - render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) - # additional regularizations - render_alpha = allmap[1:2] + # additional maps + allmaps = allmaps.permute(1, 2, 0) # [H, W, C] + render_alphas = allmap[..., 1:2] # get normal map # transform normal from view space to world space - render_normal = allmap[2:5] - render_normal = (render_normal.permute(1,2,0) @ (world_view_transform[:3,:3].T)).permute(2,0,1) + render_normal = allmap[..., 2:5] + render_normal = render_normal @ (world_view_transform[:3,:3].T) # get median depth map - render_depth_median = allmap[5:6] + render_depth_median = allmap[..., 5:6] render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) # get expected depth map - render_depth_expected = allmap[0:1] - render_depth_expected = (render_depth_expected / render_alpha) + render_depth_expected = allmap[..., 0:1] + render_depth_expected = (render_depth_expected / render_alphas) render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) # get depth distortion map - render_dist = allmap[6:7] + render_dist = allmap[..., 6:7] # psedo surface attributes # surf depth is either median or expected by setting depth_ratio to 1 or 0 @@ -660,15 +787,14 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): # assume the depth points form the 'surface' and generate psudo surface normal for regularizations. surf_normal = depth_to_normal(surf_depth, world_view_transform, full_proj_transform) - surf_normal = surf_normal.permute(2,0,1) # remember to multiply with accum_alpha since render_normal is unnormalized. - surf_normal = surf_normal * (render_alpha).detach() + surf_normal = surf_normal * (render_alphas).detach() - info = { - 'rend_alpha': render_alpha, - 'rend_normal': render_normal, - 'rend_dist': render_dist, - 'surf_depth': surf_depth, - 'surf_normal': surf_normal, + render_alphas = render_alphas.unsqueeze(0) + meta = { + 'rend_normal': render_normal.unsqueeze(0), + 'rend_dist': render_dist.unsqueeze(0), + 'surf_depth': surf_depth.unsqueeze(0), + 'surf_normal': surf_normal.unsqueeze(0), } - return render_colors, None, info + return render_colors, render_alphas, meta From 87748f43538a83edc5859924df288b9a17e42af9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 3 Jul 2024 11:49:36 -0700 Subject: [PATCH 06/66] format --- examples/simple_trainer_mcmc.py | 21 +++++++++++--------- gsplat/point_utils.py | 34 ++++++++++++++++++++++----------- gsplat/rendering.py | 33 +++++++++++++++++--------------- 3 files changed, 53 insertions(+), 35 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 4a585bb44..d66685971 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -139,7 +139,7 @@ class Config: depth_loss: bool = False # Weight for depth loss depth_lambda: float = 1e-2 - + # Enable normal loss. (experimental) normal_loss: bool = True # Weight for normal loss @@ -306,7 +306,7 @@ def rasterize_splats( colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" - + render_colors, render_alphas, info = self.rasterization_fn( means=means, quats=quats, @@ -426,9 +426,12 @@ def train(self): pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda - loss += cfg.opacity_reg * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + loss += ( + cfg.opacity_reg + * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + ) loss += cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() - + if cfg.depth_loss: # query depths from depth map points = torch.stack( @@ -448,13 +451,13 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda - + if cfg.normal_loss: normal_lambda = cfg.normal_lambda if step > 7000 else 0.0 dist_lambda = cfg.dist_lambda if step > 3000 else 0.0 - rend_normal = info['rend_normal'] - surf_normal = info['surf_normal'] + rend_normal = info["rend_normal"] + surf_normal = info["surf_normal"] normalloss = (1 - (rend_normal * surf_normal).sum(dim=-1)).mean() distloss = (info["rend_dist"]).mean() loss += normal_lambda * normalloss + dist_lambda * distloss @@ -711,7 +714,7 @@ def eval(self, step: int): ) # [1, H, W, 3] torch.cuda.synchronize() ellipse_time += time.time() - tic - + colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.normal_loss: @@ -788,7 +791,7 @@ def render_traj(self, step: int): far_plane=cfg.far_plane, render_mode="RGB+ED", ) # [1, H, W, 4] - + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [colors] if renders.shape[-1] == 4: diff --git a/gsplat/point_utils.py b/gsplat/point_utils.py index 6e5926a9d..1b1269d17 100644 --- a/gsplat/point_utils.py +++ b/gsplat/point_utils.py @@ -3,25 +3,37 @@ import torch.nn.functional as F import numpy as np + def depths_to_points(depthmap, world_view_transform, full_proj_transform): c2w = (world_view_transform.T).inverse() H, W = depthmap.shape[:2] - ndc2pix = torch.tensor([ - [W / 2, 0, 0, (W) / 2], - [0, H / 2, 0, (H) / 2], - [0, 0, 0, 1]]).float().cuda().T + ndc2pix = ( + torch.tensor([[W / 2, 0, 0, (W) / 2], [0, H / 2, 0, (H) / 2], [0, 0, 0, 1]]) + .float() + .cuda() + .T + ) projection_matrix = c2w.T @ full_proj_transform - intrins = (projection_matrix @ ndc2pix)[:3,:3].T - - grid_x, grid_y = torch.meshgrid(torch.arange(W, device='cuda').float(), torch.arange(H, device='cuda').float(), indexing='xy') - points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape(-1, 3) - rays_d = points @ intrins.inverse().T @ c2w[:3,:3].T - rays_o = c2w[:3,3] + intrins = (projection_matrix @ ndc2pix)[:3, :3].T + + grid_x, grid_y = torch.meshgrid( + torch.arange(W, device="cuda").float(), + torch.arange(H, device="cuda").float(), + indexing="xy", + ) + points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( + -1, 3 + ) + rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T + rays_o = c2w[:3, 3] points = depthmap.reshape(-1, 1) * rays_d + rays_o return points + def depth_to_normal(depth, world_view_transform, full_proj_transform): - points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape(*depth.shape[:2], 3) + points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape( + *depth.shape[:2], 3 + ) output = torch.zeros_like(points) dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 2b4980d8b..65f73f06c 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -624,7 +624,8 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) return render_colors, None, {} - + + def rasterization_2dgs_inria_wrapper( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] @@ -685,7 +686,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): C = len(viewmats) device = means.device channels = colors.shape[-1] - + scales = scales[:, :2] # [N, 2] quats = quats = F.normalize(quats, dim=-1) # [N, 4] @@ -756,7 +757,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) - + # additional maps allmaps = allmaps.permute(1, 2, 0) # [H, W, C] render_alphas = allmap[..., 1:2] @@ -764,37 +765,39 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): # get normal map # transform normal from view space to world space render_normal = allmap[..., 2:5] - render_normal = render_normal @ (world_view_transform[:3,:3].T) - + render_normal = render_normal @ (world_view_transform[:3, :3].T) + # get median depth map render_depth_median = allmap[..., 5:6] render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) # get expected depth map render_depth_expected = allmap[..., 0:1] - render_depth_expected = (render_depth_expected / render_alphas) + render_depth_expected = render_depth_expected / render_alphas render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) - + # get depth distortion map render_dist = allmap[..., 6:7] # psedo surface attributes # surf depth is either median or expected by setting depth_ratio to 1 or 0 - # for bounded scene, use median depth, i.e., depth_ratio = 1; + # for bounded scene, use median depth, i.e., depth_ratio = 1; # for unbounded scene, use expected depth, i.e., depth_ration = 0, to reduce disk anliasing. depth_ratio = 0 - surf_depth = render_depth_expected * (1-depth_ratio) + (depth_ratio) * render_depth_median - + surf_depth = ( + render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median + ) + # assume the depth points form the 'surface' and generate psudo surface normal for regularizations. surf_normal = depth_to_normal(surf_depth, world_view_transform, full_proj_transform) # remember to multiply with accum_alpha since render_normal is unnormalized. surf_normal = surf_normal * (render_alphas).detach() - + render_alphas = render_alphas.unsqueeze(0) meta = { - 'rend_normal': render_normal.unsqueeze(0), - 'rend_dist': render_dist.unsqueeze(0), - 'surf_depth': surf_depth.unsqueeze(0), - 'surf_normal': surf_normal.unsqueeze(0), + "rend_normal": render_normal.unsqueeze(0), + "rend_dist": render_dist.unsqueeze(0), + "surf_depth": surf_depth.unsqueeze(0), + "surf_normal": surf_normal.unsqueeze(0), } return render_colors, render_alphas, meta From 7ee6c20c1e2ce5d4d469a9f6f6eadd04cae4e55f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 3 Jul 2024 12:19:48 -0700 Subject: [PATCH 07/66] 3m garden first --- examples/benchmark_mcmc.sh | 4 ++-- examples/simple_trainer_mcmc.py | 12 ++++++------ gsplat/rendering.py | 21 +++++++-------------- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index fd70eea09..634526e89 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,7 +1,7 @@ RESULT_DIR=results/mcmc_sfm_2dgs # for SCENE in bicycle bonsai counter garden kitchen room stump; -for SCENE in treehill garden flowers bonsai counter kitchen room bicycle stump; +for SCENE in garden treehill bonsai counter kitchen room bicycle stump flowers; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then DATA_FACTOR=4 @@ -20,7 +20,7 @@ do elif [ "$SCENE" = "room" ]; then CAP_MAX=1500000 else - CAP_MAX=2000000 + CAP_MAX=3000000 fi echo "Running $SCENE" diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index d66685971..5761b9cae 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -63,9 +63,9 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 15_000, 30_000]) # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) + save_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 15_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -718,9 +718,9 @@ def eval(self, step: int): colors = torch.clamp(colors, 0.0, 1.0) canvas_list = [pixels, colors] if self.cfg.normal_loss: - rend_normals = info["rend_normal"] * 0.5 + 0.5 surf_normals = info["surf_normal"] * 0.5 + 0.5 - canvas_list.extend([rend_normals, surf_normals]) + rend_normals = info["rend_normal"] * 0.5 + 0.5 + canvas_list.extend([surf_normals, rend_normals]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -799,9 +799,9 @@ def render_traj(self, step: int): depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths.repeat(1, 1, 1, 3)) if self.cfg.normal_loss: - rend_normals = info["rend_normal"] * 0.5 + 0.5 surf_normals = info["surf_normal"] * 0.5 + 0.5 - canvas_list.extend([rend_normals, surf_normals]) + rend_normals = info["rend_normal"] * 0.5 + 0.5 + canvas_list.extend([surf_normals, rend_normals]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 65f73f06c..832f36fd0 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -643,17 +643,10 @@ def rasterization_2dgs_inria_wrapper( backgrounds: Optional[Tensor] = None, **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: - """Wrapper for Inria's rasterization backend. - - .. warning:: - This function exists for comparision purpose only. Only rendered image is - returned. - - .. warning:: - Inria's CUDA backend has its own LICENSE, so this function should be used with - the respect to the original LICENSE at: - https://github.com/graphdeco-inria/diff-gaussian-rasterization - + """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. + + Install the rasterization backend from + https://github.com/hbb1/diff-surfel-rasterization """ from diff_surfel_rasterization import ( GaussianRasterizationSettings, @@ -759,7 +752,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors = torch.stack(render_colors, dim=0) # additional maps - allmaps = allmaps.permute(1, 2, 0) # [H, W, C] + allmap = allmap.permute(1, 2, 0) # [H, W, C] render_alphas = allmap[..., 1:2] # get normal map @@ -795,9 +788,9 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_alphas = render_alphas.unsqueeze(0) meta = { + # "surf_depth": surf_depth.unsqueeze(0), + "surf_normal": surf_normal.unsqueeze(0), "rend_normal": render_normal.unsqueeze(0), "rend_dist": render_dist.unsqueeze(0), - "surf_depth": surf_depth.unsqueeze(0), - "surf_normal": surf_normal.unsqueeze(0), } return render_colors, render_alphas, meta From 143af517d4fd265cc6e908b973744a13b559fa90 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Wed, 3 Jul 2024 22:55:01 -0700 Subject: [PATCH 08/66] 2dgs_mcmc_sfm --- examples/benchmark_mcmc.sh | 2 +- examples/simple_trainer.py | 4 +++- examples/simple_trainer_mcmc.py | 1 - 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index 634526e89..e283abd69 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,4 +1,4 @@ -RESULT_DIR=results/mcmc_sfm_2dgs +RESULT_DIR=results/2dgs_mcmc_sfm # for SCENE in bicycle bonsai counter garden kitchen room stump; for SCENE in garden treehill bonsai counter kitchen room bicycle stump flowers; diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 2cf22fe02..985cc67af 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -184,7 +184,9 @@ def create_splats_with_optimizers( # Initialize the GS size to be the average dist of the 3 nearest neighbors dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) - scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + scales = (dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + scales[:, 2] /= 100 + scales = torch.log(scales) quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 5761b9cae..81589cbab 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -321,7 +321,6 @@ def rasterize_splats( absgrad=self.cfg.absgrad, sparse_grad=self.cfg.sparse_grad, rasterize_mode=rasterize_mode, - mode="2dgs", **kwargs, ) return render_colors, render_alphas, info From 667031480d86750e8c3b747727135cd4bfb16266 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 4 Jul 2024 12:10:56 -0700 Subject: [PATCH 09/66] train 3dgs without normal loss --- examples/benchmark_mcmc.sh | 4 +-- examples/simple_trainer.py | 2 +- examples/simple_trainer_mcmc.py | 49 +++++++++++++++++++++----------- gsplat/point_utils.py | 50 +++++++++++++++++++++++++++++++-- gsplat/rendering.py | 14 ++++----- 5 files changed, 89 insertions(+), 30 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index e283abd69..fdd5f35b6 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,4 +1,4 @@ -RESULT_DIR=results/2dgs_mcmc_sfm +RESULT_DIR=results/3dgs_mcmc_sfm # for SCENE in bicycle bonsai counter garden kitchen room stump; for SCENE in garden treehill bonsai counter kitchen room bicycle stump flowers; @@ -27,7 +27,7 @@ do # train without eval python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 2dgs \ + --model_type 3dgs \ --dist_lambda $DIST_LAMBDA \ --init_type sfm \ --cap_max $CAP_MAX \ diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 985cc67af..ff7af003e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -185,7 +185,7 @@ def create_splats_with_optimizers( dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) scales = (dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] - scales[:, 2] /= 100 + # scales[:, 2] /= 100 scales = torch.log(scales) quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 81589cbab..95aa28926 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -27,6 +27,7 @@ from gsplat import quat_scale_to_covar_preci from gsplat.rendering import rasterization, rasterization_2dgs_inria_wrapper from gsplat.relocation import compute_relocation +from gsplat.point_utils import depth_to_normal from simple_trainer import create_splats_with_optimizers @@ -143,7 +144,10 @@ class Config: # Enable normal loss. (experimental) normal_loss: bool = True # Weight for normal loss - normal_lambda: float = 0.05 + normal_lambda: float = 0.0#5 + + # Enable distortion loss. (experimental) + dist_loss: bool = False # Weight for distortion loss. Use 100 for unbounded scenes and 1000 for bounded scenes dist_lambda: float = 100 @@ -408,7 +412,7 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode="RGB+ED", # if cfg.depth_loss else "RGB", ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -453,13 +457,18 @@ def train(self): if cfg.normal_loss: normal_lambda = cfg.normal_lambda if step > 7000 else 0.0 + normals = info["rend_normal"] + depths = info["rend_depth"] + normals_depth = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) + normals_depth = normals_depth * (alphas).detach() + + normalloss = (1 - (normals * normals_depth).sum(dim=-1)).mean() + loss += normal_lambda * normalloss + + if cfg.dist_loss: dist_lambda = cfg.dist_lambda if step > 3000 else 0.0 - - rend_normal = info["rend_normal"] - surf_normal = info["surf_normal"] - normalloss = (1 - (rend_normal * surf_normal).sum(dim=-1)).mean() distloss = (info["rend_dist"]).mean() - loss += normal_lambda * normalloss + dist_lambda * distloss + loss += dist_lambda * distloss loss.backward() @@ -485,6 +494,7 @@ def train(self): self.writer.add_scalar("train/depthloss", depthloss.item(), step) if cfg.normal_loss: self.writer.add_scalar("train/normalloss", normalloss.item(), step) + if cfg.dist_loss: self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() @@ -702,7 +712,7 @@ def eval(self, step: int): torch.cuda.synchronize() tic = time.time() - colors, _, info = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -710,16 +720,22 @@ def eval(self, step: int): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - ) # [1, H, W, 3] + render_mode="RGB+ED", + ) # [1, H, W, 4] torch.cuda.synchronize() ellipse_time += time.time() - tic - - colors = torch.clamp(colors, 0.0, 1.0) + + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [pixels, colors] if self.cfg.normal_loss: - surf_normals = info["surf_normal"] * 0.5 + 0.5 - rend_normals = info["rend_normal"] * 0.5 + 0.5 - canvas_list.extend([surf_normals, rend_normals]) + normals = info["rend_normal"] * 0.5 + 0.5 + canvas_list.extend([normals]) + + depths = info["rend_depth"] + normals_depth = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) + normals_depth = normals_depth * (alphas).detach() + normals_depth = normals_depth * 0.5 + 0.5 + canvas_list.extend([normals_depth]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -798,9 +814,8 @@ def render_traj(self, step: int): depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths.repeat(1, 1, 1, 3)) if self.cfg.normal_loss: - surf_normals = info["surf_normal"] * 0.5 + 0.5 - rend_normals = info["rend_normal"] * 0.5 + 0.5 - canvas_list.extend([surf_normals, rend_normals]) + normals = info["rend_normal"] * 0.5 + 0.5 + canvas_list.extend([normals]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/point_utils.py b/gsplat/point_utils.py index 1b1269d17..afa9d16c2 100644 --- a/gsplat/point_utils.py +++ b/gsplat/point_utils.py @@ -2,9 +2,10 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np +import math -def depths_to_points(depthmap, world_view_transform, full_proj_transform): +def _depths_to_points(depthmap, world_view_transform, full_proj_transform): c2w = (world_view_transform.T).inverse() H, W = depthmap.shape[:2] ndc2pix = ( @@ -30,8 +31,8 @@ def depths_to_points(depthmap, world_view_transform, full_proj_transform): return points -def depth_to_normal(depth, world_view_transform, full_proj_transform): - points = depths_to_points(depth, world_view_transform, full_proj_transform).reshape( +def _depth_to_normal(depth, world_view_transform, full_proj_transform): + points = _depths_to_points(depth, world_view_transform, full_proj_transform).reshape( *depth.shape[:2], 3 ) output = torch.zeros_like(points) @@ -40,3 +41,46 @@ def depth_to_normal(depth, world_view_transform, full_proj_transform): normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) output[1:-1, 1:-1, :] = normal_map return output + +def depth_to_normal(depths, camtoworlds, Ks, near_plane, far_plane): + height, width = depths.shape[1:3] + viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] + + normals = [] + for cid, depth in enumerate(depths): + FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) + FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) + world_view_transform = viewmats[cid].transpose(0, 1) + projection_matrix = _getProjectionMatrix( + znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device + ).transpose(0, 1) + full_proj_transform = ( + world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) + ).squeeze(0) + normal = _depth_to_normal(depth, world_view_transform, full_proj_transform) + normals.append(normal) + normals = torch.stack(normals, dim=0) + return normals + + +def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4, device=device) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 832f36fd0..59139c2b9 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -401,6 +401,12 @@ def rasterization( "height": height, "tile_size": tile_size, } + + depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) + normals_depth = depth_to_normal(depths_expected, camtoworlds, Ks, near_plane, far_plane) + normals_depth = normals_depth * (render_alphas).detach() + meta["rend_depth"] = depths_expected + meta["rend_normal"] = normals_depth return render_colors, render_alphas, meta @@ -781,15 +787,9 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - # assume the depth points form the 'surface' and generate psudo surface normal for regularizations. - surf_normal = depth_to_normal(surf_depth, world_view_transform, full_proj_transform) - # remember to multiply with accum_alpha since render_normal is unnormalized. - surf_normal = surf_normal * (render_alphas).detach() - render_alphas = render_alphas.unsqueeze(0) meta = { - # "surf_depth": surf_depth.unsqueeze(0), - "surf_normal": surf_normal.unsqueeze(0), + "rend_depth": surf_depth.unsqueeze(0), "rend_normal": render_normal.unsqueeze(0), "rend_dist": render_dist.unsqueeze(0), } From a8fd7a9ca71a79b2b666275c6662a7ea8e962d15 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 4 Jul 2024 16:05:16 -0700 Subject: [PATCH 10/66] 3dgs normal working --- examples/simple_trainer_mcmc.py | 19 ++++++++++----- gsplat/cuda/_wrapper.py | 7 +++--- gsplat/cuda/csrc/bindings.h | 3 ++- .../cuda/csrc/fully_fused_projection_bwd.cu | 1 + .../cuda/csrc/fully_fused_projection_fwd.cu | 23 +++++++++++++++---- gsplat/cuda/csrc/utils.cuh | 20 ++++++++++++++++ gsplat/rendering.py | 10 ++++---- 7 files changed, 63 insertions(+), 20 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 95aa28926..e50af4ed8 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -144,7 +144,7 @@ class Config: # Enable normal loss. (experimental) normal_loss: bool = True # Weight for normal loss - normal_lambda: float = 0.0#5 + normal_lambda: float = 0.05 # Enable distortion loss. (experimental) dist_loss: bool = False @@ -414,10 +414,11 @@ def train(self): image_ids=image_ids, render_mode="RGB+ED", # if cfg.depth_loss else "RGB", ) - if renders.shape[-1] == 4: - colors, depths = renders[..., 0:3], renders[..., 3:4] - else: - colors, depths = renders, None + colors, _, depths = renders[..., 0:3], renders[..., 3:6], renders[..., 6:7] + # if renders.shape[-1] == 4: + # colors, depths = renders[..., 0:3], renders[..., 3:4] + # else: + # colors, depths = renders, None if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) @@ -796,7 +797,7 @@ def render_traj(self, step: int): canvas_all = [] for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): - renders, _, info = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds[i : i + 1], Ks=K[None], width=width, @@ -816,6 +817,12 @@ def render_traj(self, step: int): if self.cfg.normal_loss: normals = info["rend_normal"] * 0.5 + 0.5 canvas_list.extend([normals]) + + depths = info["rend_depth"] + normals_depth = depth_to_normal(depths, camtoworlds, K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) + normals_depth = normals_depth * (alphas).detach() + normals_depth = normals_depth * 0.5 + 0.5 + canvas_list.extend([normals_depth]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 9043181ce..875f40338 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -709,7 +709,7 @@ def forward( calc_compensations: bool, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: # "covars" and {"quats", "scales"} are mutually exclusive - radii, means2d, depths, conics, compensations = _make_lazy_cuda_func( + radii, means2d, depths, normals, conics, compensations = _make_lazy_cuda_func( "fully_fused_projection_fwd" )( means, @@ -735,10 +735,10 @@ def forward( ctx.height = height ctx.eps2d = eps2d - return radii, means2d, depths, conics, compensations + return radii, means2d, depths, normals, conics, compensations @staticmethod - def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): + def backward(ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations): ( means, covars, @@ -772,6 +772,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): compensations, v_means2d.contiguous(), v_depths.contiguous(), + v_normals.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index c983f461e..bde378773 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -65,7 +65,7 @@ world_to_cam_bwd_tensor(const torch::Tensor &means, // [N, 3] const bool means_requires_grad, const bool covars_requires_grad, const bool viewmats_requires_grad); -std::tuple +std::tuple fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] optional @@ -94,6 +94,7 @@ fully_fused_projection_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [C, N, 2] const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional const bool viewmats_requires_grad); diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 0f8e16ba5..78748b176 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -201,6 +201,7 @@ fully_fused_projection_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [C, N, 2] const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional const bool viewmats_requires_grad) { diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 53b3d3388..32db3ed48 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -31,6 +31,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, int32_t *__restrict__ radii, // [C, N] T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] + T *__restrict__ normals, // [C, N, 3] T *__restrict__ conics, // [C, N, 3] T *__restrict__ compensations // [C, N] optional ) { @@ -46,6 +47,8 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, means += gid * 3; viewmats += cid * 16; Ks += cid * 9; + quats += gid * 4; + scales += gid * 3; // glm is column-major but input is row-major mat3 R = mat3(viewmats[0], viewmats[4], viewmats[8], // 1st column @@ -72,8 +75,6 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, ); } else { // compute from quaternions and scales - quats += gid * 4; - scales += gid * 3; quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); } mat3 covar_c; @@ -115,11 +116,22 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, return; } + glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + glm::mat3 S = scale_to_mat(glm::make_vec3(scales), 1.f); + glm::mat3 L = rotmat * S; + // float3 normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmats); + // write to outputs radii[idx] = (int32_t)radius; means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; depths[idx] = mean_c.z; + // normals[idx * 3] = normal.x; + // normals[idx * 3 + 1] = normal.y; + // normals[idx * 3 + 2] = normal.z; + normals[idx * 3] = L[2].x; + normals[idx * 3 + 1] = L[2].y; + normals[idx * 3 + 2] = L[2].z; conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; conics[idx * 3 + 2] = covar2d_inv[1][1]; @@ -129,7 +141,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, } -std::tuple +std::tuple fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] optional @@ -159,6 +171,7 @@ fully_fused_projection_fwd_tensor( torch::Tensor radii = torch::empty({C, N}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); torch::Tensor depths = torch::empty({C, N}, means.options()); + torch::Tensor normals = torch::empty({C, N, 3}, means.options()); torch::Tensor conics = torch::empty({C, N, 3}, means.options()); torch::Tensor compensations; if (calc_compensations) { @@ -173,9 +186,9 @@ fully_fused_projection_fwd_tensor( scales.has_value() ? scales.value().data_ptr() : nullptr, viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, eps2d, near_plane, far_plane, radius_clip, radii.data_ptr(), - means2d.data_ptr(), depths.data_ptr(), + means2d.data_ptr(), depths.data_ptr(), normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr); } - return std::make_tuple(radii, means2d, depths, conics, compensations); + return std::make_tuple(radii, means2d, depths, normals, conics, compensations); } diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 57751315f..f705ef0b9 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -343,4 +343,24 @@ inline __device__ void add_blur_vjp(const T eps2d, const mat2 conic_blur, v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); } +__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) +{ + float3 transformed = { + matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, + matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, + matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, + }; + return transformed; +} + +inline __device__ glm::mat3 +scale_to_mat(const glm::vec3 scale, const float glob_scale) { + glm::mat3 S = glm::mat3(1.f); + S[0][0] = glob_scale * scale.x; + S[1][1] = glob_scale * scale.y; + S[2][2] = glob_scale * scale.z; + return S; +} + + #endif // GSPLAT_CUDA_UTILS_H \ No newline at end of file diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 59139c2b9..ea79d0f91 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -249,7 +249,7 @@ def rasterization( opacities = opacities[gaussian_ids] # [nnz] else: # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. - radii, means2d, depths, conics, compensations = proj_results + radii, means2d, depths, normals, conics, compensations = proj_results opacities = opacities.repeat(C, 1) # [C, N] camera_ids, gaussian_ids = None, None @@ -318,7 +318,7 @@ def rasterization( # Rasterize to pixels if render_mode in ["RGB+D", "RGB+ED"]: - colors = torch.cat((colors, depths[..., None]), dim=-1) + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) if backgrounds is not None: backgrounds = torch.cat( [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 @@ -403,10 +403,10 @@ def rasterization( } depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) - normals_depth = depth_to_normal(depths_expected, camtoworlds, Ks, near_plane, far_plane) - normals_depth = normals_depth * (render_alphas).detach() + render_normals = render_colors[..., 3:6] + render_normals = F.normalize(render_normals, dim=-1) meta["rend_depth"] = depths_expected - meta["rend_normal"] = normals_depth + meta["rend_normal"] = render_normals return render_colors, render_alphas, meta From 5870b254e4c6131974f8b5564aedbe5b46331200 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 4 Jul 2024 17:45:40 -0700 Subject: [PATCH 11/66] baseline no gradient --- examples/simple_trainer_mcmc.py | 2 +- .../cuda/csrc/fully_fused_projection_bwd.cu | 10 +++++++++- .../cuda/csrc/fully_fused_projection_fwd.cu | 12 +++--------- gsplat/cuda/csrc/utils.cuh | 19 ------------------- 4 files changed, 13 insertions(+), 30 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index e50af4ed8..56ff20700 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -819,7 +819,7 @@ def render_traj(self, step: int): canvas_list.extend([normals]) depths = info["rend_depth"] - normals_depth = depth_to_normal(depths, camtoworlds, K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) + normals_depth = depth_to_normal(depths, camtoworlds[i : i + 1], K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) normals_depth = normals_depth * (alphas).detach() normals_depth = normals_depth * 0.5 + 0.5 canvas_list.extend([normals_depth]) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 78748b176..600522611 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -33,6 +33,7 @@ __global__ void fully_fused_projection_bwd_kernel( // grad outputs const T *__restrict__ v_means2d, // [C, N, 2] const T *__restrict__ v_depths, // [C, N] + const T *__restrict__ v_normals, // [C, N, 3] const T *__restrict__ v_conics, // [C, N, 3] const T *__restrict__ v_compensations, // [C, N] optional // grad inputs @@ -59,6 +60,7 @@ __global__ void fully_fused_projection_bwd_kernel( v_means2d += idx * 2; v_depths += idx; + v_normals += idx * 3; v_conics += idx * 3; // vjp: compute the inverse of the 2d covariance @@ -152,6 +154,11 @@ __global__ void fully_fused_projection_bwd_kernel( vec4 v_quat(0.f); vec3 v_scale(0.f); quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // // add contribution from v_normals + // mat3 v_R = mat3(0, 0, v_normals[0], 0, 0, v_normals[1], 0, 0, v_normals[2]); + // quat_to_rotmat_vjp(quat, v_R, v_quat); + warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); if (warp_group_g.thread_rank() == 0) { @@ -220,6 +227,7 @@ fully_fused_projection_bwd_tensor( CHECK_INPUT(conics); CHECK_INPUT(v_means2d); CHECK_INPUT(v_depths); + CHECK_INPUT(v_normals); CHECK_INPUT(v_conics); if (compensations.has_value()) { CHECK_INPUT(compensations.value()); @@ -256,7 +264,7 @@ fully_fused_projection_bwd_tensor( compensations.has_value() ? compensations.value().data_ptr() : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), - v_conics.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() : nullptr, v_means.data_ptr(), diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 32db3ed48..cb93fb198 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -117,21 +117,15 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, } glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); - glm::mat3 S = scale_to_mat(glm::make_vec3(scales), 1.f); - glm::mat3 L = rotmat * S; - // float3 normal = transformVec4x3({L[2].x, L[2].y, L[2].z}, viewmats); // write to outputs radii[idx] = (int32_t)radius; means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; depths[idx] = mean_c.z; - // normals[idx * 3] = normal.x; - // normals[idx * 3 + 1] = normal.y; - // normals[idx * 3 + 2] = normal.z; - normals[idx * 3] = L[2].x; - normals[idx * 3 + 1] = L[2].y; - normals[idx * 3 + 2] = L[2].z; + normals[idx * 3] = rotmat[2].x; + normals[idx * 3 + 1] = rotmat[2].y; + normals[idx * 3 + 2] = rotmat[2].z; conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; conics[idx * 3 + 2] = covar2d_inv[1][1]; diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index f705ef0b9..9fd7118db 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -343,24 +343,5 @@ inline __device__ void add_blur_vjp(const T eps2d, const mat2 conic_blur, v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); } -__forceinline__ __device__ float3 transformVec4x3(const float3& p, const float* matrix) -{ - float3 transformed = { - matrix[0] * p.x + matrix[4] * p.y + matrix[8] * p.z, - matrix[1] * p.x + matrix[5] * p.y + matrix[9] * p.z, - matrix[2] * p.x + matrix[6] * p.y + matrix[10] * p.z, - }; - return transformed; -} - -inline __device__ glm::mat3 -scale_to_mat(const glm::vec3 scale, const float glob_scale) { - glm::mat3 S = glm::mat3(1.f); - S[0][0] = glob_scale * scale.x; - S[1][1] = glob_scale * scale.y; - S[2][2] = glob_scale * scale.z; - return S; -} - #endif // GSPLAT_CUDA_UTILS_H \ No newline at end of file From a8adfe56aa37d590bd11d3eea1c48bd047927cb2 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Thu, 4 Jul 2024 19:31:19 -0700 Subject: [PATCH 12/66] normal backprob --- .../cuda/csrc/fully_fused_projection_bwd.cu | 7 +--- gsplat/cuda/csrc/utils.cuh | 40 +++++++++++++++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 600522611..928367106 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -153,11 +153,8 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - - // // add contribution from v_normals - // mat3 v_R = mat3(0, 0, v_normals[0], 0, 0, v_normals[1], 0, 0, v_normals[2]); - // quat_to_rotmat_vjp(quat, v_R, v_quat); + // quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + quat_scale_to_covar_vjp_normal(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 9fd7118db..807d506ec 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -106,6 +106,46 @@ inline __device__ void quat_scale_to_covar_vjp( v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; } +template +inline __device__ void quat_scale_to_covar_vjp_normal( + // fwd inputs + const vec4 quat, const vec3 scale, + // precompute + const mat3 R, + // grad outputs + const mat3 v_covar, + const vec3 v_normal, + // grad inputs + vec4 &v_quat, vec3 &v_scale) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T sx = scale[0], sy = scale[1], sz = scale[2]; + + // M = R * S + mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); + mat3 M = R * S; + + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + // so + // for D = M * Mt, + // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M + mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; + mat3 v_R = v_M * S; + + // add contribution from v_normal + // printf("v_normal: %.8f, %.8f, %.8f \n", v_normal.x, v_normal.y, v_normal.z); + v_R[2] += v_normal; + + // grad for (quat, scale) from covar + quat_to_rotmat_vjp(quat, v_R, v_quat); + + v_scale[0] += R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; + v_scale[1] += R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; + v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; +} + + template inline __device__ void quat_scale_to_preci_vjp( // fwd inputs From 08637cb45cb35ef985ff51e917bf82f0a97b8687 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 11:01:05 -0700 Subject: [PATCH 13/66] cleanup --- examples/benchmark_mcmc.sh | 7 +-- examples/simple_trainer_mcmc.py | 95 ++++++++++++++------------------- 2 files changed, 43 insertions(+), 59 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index fdd5f35b6..f065860a6 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,14 +1,12 @@ -RESULT_DIR=results/3dgs_mcmc_sfm +RESULT_DIR=results/2dgs_mcmc_sfm # for SCENE in bicycle bonsai counter garden kitchen room stump; for SCENE in garden treehill bonsai counter kitchen room bicycle stump flowers; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then DATA_FACTOR=4 - DIST_LAMBDA=100 else DATA_FACTOR=2 - DIST_LAMBDA=1000 fi if [ "$SCENE" = "bonsai" ]; then @@ -27,8 +25,7 @@ do # train without eval python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 3dgs \ - --dist_lambda $DIST_LAMBDA \ + --model_type 2dgs \ --init_type sfm \ --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 56ff20700..be2175bac 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -141,15 +141,12 @@ class Config: # Weight for depth loss depth_lambda: float = 1e-2 - # Enable normal loss. (experimental) - normal_loss: bool = True - # Weight for normal loss - normal_lambda: float = 0.05 - - # Enable distortion loss. (experimental) - dist_loss: bool = False - # Weight for distortion loss. Use 100 for unbounded scenes and 1000 for bounded scenes - dist_lambda: float = 100 + # Enable normal consistency loss. (experimental) + normal_consistency_loss: bool = True + # Weight for normal consistency loss + normal_consistency_lambda: float = 0.05 + # Start refining GSs after this iteration + normal_consistency_start_iter: int = 500 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -180,6 +177,10 @@ def __init__(self, cfg: Config) -> None: self.rasterization_fn = rasterization_2dgs_inria_wrapper else: raise ValueError(f"Unsupported model type: {cfg.model_type}") + + self.render_mode = "RGB" + if cfg.depth_loss or cfg.normal_consistency_loss: + self.render_mode = "RGB+ED" # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -412,13 +413,11 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED", # if cfg.depth_loss else "RGB", + render_mode=self.render_mode, ) - colors, _, depths = renders[..., 0:3], renders[..., 3:6], renders[..., 6:7] - # if renders.shape[-1] == 4: - # colors, depths = renders[..., 0:3], renders[..., 3:4] - # else: - # colors, depths = renders, None + colors = renders[..., 0:3] + depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None + normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) @@ -456,20 +455,12 @@ def train(self): depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda - if cfg.normal_loss: - normal_lambda = cfg.normal_lambda if step > 7000 else 0.0 - normals = info["rend_normal"] - depths = info["rend_depth"] - normals_depth = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) - normals_depth = normals_depth * (alphas).detach() + if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: + normals_surf = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) + # normals_surf = normals_surf * (alphas).detach() - normalloss = (1 - (normals * normals_depth).sum(dim=-1)).mean() - loss += normal_lambda * normalloss - - if cfg.dist_loss: - dist_lambda = cfg.dist_lambda if step > 3000 else 0.0 - distloss = (info["rend_dist"]).mean() - loss += dist_lambda * distloss + normalconsistencyloss = (1 - (normals * normals_surf).sum(dim=-1)).mean() + loss += cfg.normal_consistency_lambda * normalconsistencyloss loss.backward() @@ -493,10 +484,8 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) - if cfg.normal_loss: - self.writer.add_scalar("train/normalloss", normalloss.item(), step) - if cfg.dist_loss: - self.writer.add_scalar("train/distloss", distloss.item(), step) + if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: + self.writer.add_scalar("train/normalconsistencyloss", normalconsistencyloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) @@ -721,22 +710,20 @@ def eval(self, step: int): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - render_mode="RGB+ED", + render_mode=self.render_mode, ) # [1, H, W, 4] torch.cuda.synchronize() ellipse_time += time.time() - tic colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [pixels, colors] - if self.cfg.normal_loss: - normals = info["rend_normal"] * 0.5 + 0.5 - canvas_list.extend([normals]) - - depths = info["rend_depth"] - normals_depth = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) - normals_depth = normals_depth * (alphas).detach() - normals_depth = normals_depth * 0.5 + 0.5 - canvas_list.extend([normals_depth]) + if self.cfg.normal_consistency_loss: + depths = renders[..., 3:4] + normals = renders[..., 4:7] + normals_surf = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) + # normals_surf = normals_surf * (alphas).detach() + canvas_list.extend([normals * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -805,24 +792,24 @@ def render_traj(self, step: int): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - render_mode="RGB+ED", - ) # [1, H, W, 4] + render_mode=self.render_mode, + ) # [1, H, W, C] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [colors] - if renders.shape[-1] == 4: + + if renders.shape[-1] >= 4: depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths.repeat(1, 1, 1, 3)) - if self.cfg.normal_loss: - normals = info["rend_normal"] * 0.5 + 0.5 - canvas_list.extend([normals]) - - depths = info["rend_depth"] - normals_depth = depth_to_normal(depths, camtoworlds[i : i + 1], K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) - normals_depth = normals_depth * (alphas).detach() - normals_depth = normals_depth * 0.5 + 0.5 - canvas_list.extend([normals_depth]) + + if renders.shape[-1] >= 5: + depths = renders[..., 3:4] + normals = renders[..., 4:7] + normals_surf = depth_to_normal(depths, camtoworlds[i : i + 1], K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) + # normals_surf = normals_surf * (alphas).detach() + canvas_list.extend([normals * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() From ea6ce1a973f7e9615fcc76520efd230faf48a1a7 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 11:01:17 -0700 Subject: [PATCH 14/66] cleanup --- gsplat/rendering.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index ea79d0f91..79d770895 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -318,7 +318,7 @@ def rasterization( # Rasterize to pixels if render_mode in ["RGB+D", "RGB+ED"]: - colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + colors = torch.cat((colors, depths[..., None], normals), dim=-1) if backgrounds is not None: backgrounds = torch.cat( [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 @@ -377,8 +377,9 @@ def rasterization( # normalize the accumulated depth to get the expected depth render_colors = torch.cat( [ - render_colors[..., :-1], - render_colors[..., -1:] / render_alphas.clamp(min=1e-10), + render_colors[..., :3], + render_colors[..., 3:4] / render_alphas.clamp(min=1e-10), + F.normalize(render_colors[..., 4:7], dim=-1), ], dim=-1, ) @@ -401,12 +402,6 @@ def rasterization( "height": height, "tile_size": tile_size, } - - depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) - render_normals = render_colors[..., 3:6] - render_normals = F.normalize(render_normals, dim=-1) - meta["rend_depth"] = depths_expected - meta["rend_normal"] = render_normals return render_colors, render_alphas, meta From f7f0b9c138ea5dadd132f2d382f045f7fb59b4f5 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 11:27:23 -0700 Subject: [PATCH 15/66] cleanup --- examples/simple_trainer_mcmc.py | 91 ++++++++++++++++++++++++--------- gsplat/cuda/_wrapper.py | 4 +- gsplat/point_utils.py | 9 ++-- gsplat/rendering.py | 35 ++++--------- 4 files changed, 86 insertions(+), 53 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index be2175bac..ce93a4333 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -64,9 +64,13 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 15_000, 30_000]) + eval_steps: List[int] = field( + default_factory=lambda: [1_000, 7_000, 15_000, 30_000] + ) # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 15_000, 30_000]) + save_steps: List[int] = field( + default_factory=lambda: [1_000, 7_000, 15_000, 30_000] + ) # Initialization strategy init_type: str = "sfm" @@ -145,8 +149,15 @@ class Config: normal_consistency_loss: bool = True # Weight for normal consistency loss normal_consistency_lambda: float = 0.05 - # Start refining GSs after this iteration - normal_consistency_start_iter: int = 500 + # Start applying normal consistency loss after this iteration + normal_consistency_start_iter: int = 7000 + + # Distoration loss. (experimental) + dist_loss: bool = True + # Weight for distortion loss + dist_lambda: float = 100 + # Start applying distortion loss after this iteration + dist_start_iter: int = 3000 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -177,7 +188,7 @@ def __init__(self, cfg: Config) -> None: self.rasterization_fn = rasterization_2dgs_inria_wrapper else: raise ValueError(f"Unsupported model type: {cfg.model_type}") - + self.render_mode = "RGB" if cfg.depth_loss or cfg.normal_consistency_loss: self.render_mode = "RGB+ED" @@ -456,11 +467,22 @@ def train(self): loss += depthloss * cfg.depth_lambda if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: - normals_surf = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) - # normals_surf = normals_surf * (alphas).detach() - - normalconsistencyloss = (1 - (normals * normals_surf).sum(dim=-1)).mean() - loss += cfg.normal_consistency_lambda * normalconsistencyloss + normals_surf = depth_to_normal( + depths, + camtoworlds, + Ks, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) + normals_surf = normals_surf * (alphas).detach() + normalconsistencyloss = ( + 1 - (normals * normals_surf).sum(dim=-1) + ).mean() + loss += normalconsistencyloss * cfg.normal_consistency_lambda + + if cfg.dist_loss and step > cfg.dist_start_iter: + distloss = info["render_distloss"].mean() + loss += distloss * cfg.dist_lambda loss.backward() @@ -484,8 +506,17 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) - if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: - self.writer.add_scalar("train/normalconsistencyloss", normalconsistencyloss.item(), step) + if ( + cfg.normal_consistency_loss + and step > cfg.normal_consistency_start_iter + ): + self.writer.add_scalar( + "train/normalconsistencyloss", + normalconsistencyloss.item(), + step, + ) + if cfg.dist_loss and step > cfg.dist_start_iter: + self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) @@ -711,17 +742,23 @@ def eval(self, step: int): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode=self.render_mode, - ) # [1, H, W, 4] + ) # [1, H, W, C] torch.cuda.synchronize() ellipse_time += time.time() - tic - + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [pixels, colors] if self.cfg.normal_consistency_loss: - depths = renders[..., 3:4] - normals = renders[..., 4:7] - normals_surf = depth_to_normal(depths, camtoworlds, Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane) - # normals_surf = normals_surf * (alphas).detach() + depths = renders[..., 3:4] # [1, H, W, 1] + normals = renders[..., 4:7] # [1, H, W, 3] + normals_surf = depth_to_normal( + depths, + camtoworlds, + Ks, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) + normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) @@ -797,17 +834,23 @@ def render_traj(self, step: int): colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] canvas_list = [colors] - + if renders.shape[-1] >= 4: depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths.repeat(1, 1, 1, 3)) - + if renders.shape[-1] >= 5: - depths = renders[..., 3:4] - normals = renders[..., 4:7] - normals_surf = depth_to_normal(depths, camtoworlds[i : i + 1], K[None], near_plane=cfg.near_plane, far_plane=cfg.far_plane) - # normals_surf = normals_surf * (alphas).detach() + depths = renders[..., 3:4] # [1, H, W, 1] + normals = renders[..., 4:7] # [1, H, W, 3] + normals_surf = depth_to_normal( + depths, + camtoworlds[i : i + 1], + K[None], + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) + normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 875f40338..1dbadf9c9 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -738,7 +738,9 @@ def forward( return radii, means2d, depths, normals, conics, compensations @staticmethod - def backward(ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations): + def backward( + ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations + ): ( means, covars, diff --git a/gsplat/point_utils.py b/gsplat/point_utils.py index afa9d16c2..d8c30ed92 100644 --- a/gsplat/point_utils.py +++ b/gsplat/point_utils.py @@ -32,9 +32,9 @@ def _depths_to_points(depthmap, world_view_transform, full_proj_transform): def _depth_to_normal(depth, world_view_transform, full_proj_transform): - points = _depths_to_points(depth, world_view_transform, full_proj_transform).reshape( - *depth.shape[:2], 3 - ) + points = _depths_to_points( + depth, world_view_transform, full_proj_transform + ).reshape(*depth.shape[:2], 3) output = torch.zeros_like(points) dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) @@ -42,10 +42,11 @@ def _depth_to_normal(depth, world_view_transform, full_proj_transform): output[1:-1, 1:-1, :] = normal_map return output + def depth_to_normal(depths, camtoworlds, Ks, near_plane, far_plane): height, width = depths.shape[1:3] viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] - + normals = [] for cid, depth in enumerate(depths): FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 79d770895..298c5f43d 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -645,7 +645,7 @@ def rasterization_2dgs_inria_wrapper( **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. - + Install the rasterization backend from https://github.com/hbb1/diff-surfel-rasterization """ @@ -753,39 +753,26 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors = torch.stack(render_colors, dim=0) # additional maps - allmap = allmap.permute(1, 2, 0) # [H, W, C] + allmap = allmap.permute(1, 2, 0).unsqueeze(0) # [1, H, W, C] + render_depth_expected = allmap[..., 0:1] render_alphas = allmap[..., 1:2] - - # get normal map - # transform normal from view space to world space render_normal = allmap[..., 2:5] - render_normal = render_normal @ (world_view_transform[:3, :3].T) - - # get median depth map render_depth_median = allmap[..., 5:6] - render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) + render_dist = allmap[..., 6:7] - # get expected depth map - render_depth_expected = allmap[..., 0:1] + render_normal = render_normal @ (world_view_transform[:3, :3].T) render_depth_expected = render_depth_expected / render_alphas render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) + render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) - # get depth distortion map - render_dist = allmap[..., 6:7] - - # psedo surface attributes - # surf depth is either median or expected by setting depth_ratio to 1 or 0 + # render_depth is either median or expected by setting depth_ratio to 1 or 0 # for bounded scene, use median depth, i.e., depth_ratio = 1; - # for unbounded scene, use expected depth, i.e., depth_ration = 0, to reduce disk anliasing. + # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. depth_ratio = 0 - surf_depth = ( + render_depth = ( render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - render_alphas = render_alphas.unsqueeze(0) - meta = { - "rend_depth": surf_depth.unsqueeze(0), - "rend_normal": render_normal.unsqueeze(0), - "rend_dist": render_dist.unsqueeze(0), - } + meta = {"render_distloss": render_dist} + render_colors = torch.cat([render_colors, render_depth, render_normal], dim=-1) return render_colors, render_alphas, meta From 468c451d55e7442b8c374496f8054251ba4bd2f5 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 13:00:51 -0700 Subject: [PATCH 16/66] cmocean dense --- examples/simple_trainer.py | 4 +- examples/simple_trainer_mcmc.py | 74 ++++++++++++++-------- gsplat/color_utils.py | 32 ++++++++++ gsplat/{point_utils.py => normal_utils.py} | 0 gsplat/rendering.py | 1 - 5 files changed, 82 insertions(+), 29 deletions(-) create mode 100644 gsplat/color_utils.py rename gsplat/{point_utils.py => normal_utils.py} (100%) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index ff7af003e..2cf22fe02 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -184,9 +184,7 @@ def create_splats_with_optimizers( # Initialize the GS size to be the average dist of the 3 nearest neighbors dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] dist_avg = torch.sqrt(dist2_avg) - scales = (dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] - # scales[:, 2] /= 100 - scales = torch.log(scales) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] quats = torch.rand((N, 4)) # [N, 4] opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index ce93a4333..082a1a406 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -27,7 +27,8 @@ from gsplat import quat_scale_to_covar_preci from gsplat.rendering import rasterization, rasterization_2dgs_inria_wrapper from gsplat.relocation import compute_relocation -from gsplat.point_utils import depth_to_normal +from gsplat.normal_utils import depth_to_normal +from gsplat.color_utils import apply_float_colormap from simple_trainer import create_splats_with_optimizers @@ -157,7 +158,7 @@ class Config: # Weight for distortion loss dist_lambda: float = 100 # Start applying distortion loss after this iteration - dist_start_iter: int = 3000 + dist_start_iter: int = 0 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -746,11 +747,15 @@ def eval(self, step: int): torch.cuda.synchronize() ellipse_time += time.time() - tic - colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None + normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None + canvas_list = [pixels, colors] + if self.cfg.depth_loss: + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) if self.cfg.normal_consistency_loss: - depths = renders[..., 3:4] # [1, H, W, 1] - normals = renders[..., 4:7] # [1, H, W, 3] normals_surf = depth_to_normal( depths, camtoworlds, @@ -761,6 +766,14 @@ def eval(self, step: int): normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) + if self.cfg.dist_loss: + distloss = info["render_distloss"] + distloss = (distloss - distloss.min()) / ( + distloss.max() - distloss.min() + ) + canvas_list.append( + apply_float_colormap(1 - distloss, colormap="cmo.dense") + ) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -805,25 +818,30 @@ def render_traj(self, step: int): cfg = self.cfg device = self.device - camtoworlds = self.parser.camtoworlds[5:-5] - camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] - camtoworlds = np.concatenate( + camtoworlds_all = self.parser.camtoworlds[5:-5] + camtoworlds_all = generate_interpolated_path(camtoworlds_all, 1) # [N, 3, 4] + camtoworlds_all = np.concatenate( [ - camtoworlds, - np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), ], axis=1, ) # [N, 4, 4] - camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) width, height = list(self.parser.imsize_dict.values())[0] canvas_all = [] - for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + renders, alphas, info = self.rasterize_splats( - camtoworlds=camtoworlds[i : i + 1], - Ks=K[None], + camtoworlds=camtoworlds, + Ks=Ks, width=width, height=height, sh_degree=cfg.sh_degree, @@ -832,27 +850,33 @@ def render_traj(self, step: int): render_mode=self.render_mode, ) # [1, H, W, C] - colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] - canvas_list = [colors] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None + normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None - if renders.shape[-1] >= 4: - depths = renders[..., 3:4] # [1, H, W, 1] + canvas_list = [colors] + if self.cfg.depth_loss: depths = (depths - depths.min()) / (depths.max() - depths.min()) - canvas_list.append(depths.repeat(1, 1, 1, 3)) - - if renders.shape[-1] >= 5: - depths = renders[..., 3:4] # [1, H, W, 1] - normals = renders[..., 4:7] # [1, H, W, 3] + canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) + if self.cfg.normal_consistency_loss: normals_surf = depth_to_normal( depths, - camtoworlds[i : i + 1], - K[None], + camtoworlds, + Ks, near_plane=cfg.near_plane, far_plane=cfg.far_plane, ) normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) + if self.cfg.dist_loss: + distloss = info["render_distloss"] + distloss = (distloss - distloss.min()) / ( + distloss.max() - distloss.min() + ) + canvas_list.append( + apply_float_colormap(1 - distloss, colormap="cmo.dense") + ) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/color_utils.py b/gsplat/color_utils.py new file mode 100644 index 000000000..a52da6c5d --- /dev/null +++ b/gsplat/color_utils.py @@ -0,0 +1,32 @@ +import matplotlib +from jaxtyping import Float +from torch import Tensor +import cmocean + + +def apply_float_colormap( + image: Float[Tensor, "*bs 1"], colormap: str = "viridis" +) -> Float[Tensor, "*bs rgb=3"]: + """Convert single channel to a color image. + + Args: + image: Single channel image. + colormap: Colormap for image. + + Returns: + Tensor: Colored image with colors in [0, 1] + """ + if colormap == "default": + colormap = "turbo" + + image = torch.nan_to_num(image, 0) + if colormap == "gray": + return image.repeat(1, 1, 3) + image_long = (image * 255).long() + image_long_min = torch.min(image_long) + image_long_max = torch.max(image_long) + assert image_long_min >= 0, f"the min value is {image_long_min}" + assert image_long_max <= 255, f"the max value is {image_long_max}" + return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[ + :, :3 + ][image_long[..., 0]] diff --git a/gsplat/point_utils.py b/gsplat/normal_utils.py similarity index 100% rename from gsplat/point_utils.py rename to gsplat/normal_utils.py diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 298c5f43d..2702e7c3b 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -13,7 +13,6 @@ rasterize_to_pixels, spherical_harmonics, ) -from .point_utils import depth_to_normal def rasterization( From 8cfa0fcc015c639a8c781d4e06dd2811cde17bd3 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 13:27:16 -0700 Subject: [PATCH 17/66] cmo ice --- examples/simple_trainer_mcmc.py | 15 +++++++++------ gsplat/color_utils.py | 1 + gsplat/rendering.py | 8 ++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 082a1a406..248cfd021 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -25,7 +25,11 @@ set_random_seed, ) from gsplat import quat_scale_to_covar_preci -from gsplat.rendering import rasterization, rasterization_2dgs_inria_wrapper +from gsplat.rendering import ( + rasterization, + rasterization_2dgs_inria_wrapper, + rasterization_inria_wrapper, +) from gsplat.relocation import compute_relocation from gsplat.normal_utils import depth_to_normal from gsplat.color_utils import apply_float_colormap @@ -158,7 +162,7 @@ class Config: # Weight for distortion loss dist_lambda: float = 100 # Start applying distortion loss after this iteration - dist_start_iter: int = 0 + dist_start_iter: int = 3000 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -185,6 +189,8 @@ def __init__(self, cfg: Config) -> None: self.device = "cuda" if cfg.model_type == "3dgs": self.rasterization_fn = rasterization + elif cfg.model_type == "3dgs_inria": + self.rasterization_fn = rasterization_inria_wrapper elif cfg.model_type == "2dgs": self.rasterization_fn = rasterization_2dgs_inria_wrapper else: @@ -304,7 +310,6 @@ def rasterize_splats( **kwargs, ) -> Tuple[Tensor, Tensor, Dict]: means = self.splats["means3d"] # [N, 3] - # rasterization does normalization internally quats = self.splats["quats"] # [N, 4] scales = torch.exp(self.splats["scales"]) # [N, 3] opacities = torch.sigmoid(self.splats["opacities"]) # [N,] @@ -771,9 +776,7 @@ def eval(self, step: int): distloss = (distloss - distloss.min()) / ( distloss.max() - distloss.min() ) - canvas_list.append( - apply_float_colormap(1 - distloss, colormap="cmo.dense") - ) + canvas_list.append(apply_float_colormap(distloss, colormap="cmo.ice")) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/color_utils.py b/gsplat/color_utils.py index a52da6c5d..60581116d 100644 --- a/gsplat/color_utils.py +++ b/gsplat/color_utils.py @@ -1,5 +1,6 @@ import matplotlib from jaxtyping import Float +import torch from torch import Tensor import cmocean diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 2702e7c3b..ecc4a2ece 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -555,6 +555,9 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): device = means.device channels = colors.shape[-1] + # rasterization from inria does not normalization internally + quats = F.normalize(quats, dim=-1) # [N, 4] + render_colors = [] for cid in range(C): FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) @@ -645,7 +648,7 @@ def rasterization_2dgs_inria_wrapper( ) -> Tuple[Tensor, Tensor, Dict]: """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. - Install the rasterization backend from + Install the 2DGS rasterization backend from https://github.com/hbb1/diff-surfel-rasterization """ from diff_surfel_rasterization import ( @@ -680,8 +683,9 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): device = means.device channels = colors.shape[-1] + # rasterization from inria does not normalization internally + quats = F.normalize(quats, dim=-1) # [N, 4] scales = scales[:, :2] # [N, 2] - quats = quats = F.normalize(quats, dim=-1) # [N, 4] render_colors = [] for cid in range(C): From 20cd888644561e4f7c10f54d333c3e5e0e333e6b Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 13:58:36 -0700 Subject: [PATCH 18/66] voltage --- examples/simple_trainer_mcmc.py | 6 ++++-- gsplat/color_utils.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 248cfd021..453b8299e 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -776,7 +776,9 @@ def eval(self, step: int): distloss = (distloss - distloss.min()) / ( distloss.max() - distloss.min() ) - canvas_list.append(apply_float_colormap(distloss, colormap="cmo.ice")) + canvas_list.append( + apply_float_colormap(distloss, colormap="cmr.voltage") + ) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -878,7 +880,7 @@ def render_traj(self, step: int): distloss.max() - distloss.min() ) canvas_list.append( - apply_float_colormap(1 - distloss, colormap="cmo.dense") + apply_float_colormap(distloss, colormap="cmr.voltage") ) # write images diff --git a/gsplat/color_utils.py b/gsplat/color_utils.py index 60581116d..69d9bedf2 100644 --- a/gsplat/color_utils.py +++ b/gsplat/color_utils.py @@ -2,7 +2,7 @@ from jaxtyping import Float import torch from torch import Tensor -import cmocean +import cmasher def apply_float_colormap( @@ -29,5 +29,5 @@ def apply_float_colormap( assert image_long_min >= 0, f"the min value is {image_long_min}" assert image_long_max <= 255, f"the max value is {image_long_max}" return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[ - :, :3 - ][image_long[..., 0]] + image_long[..., 0] + ] From 503ac93cfa5c2a9a516c2c1b7ddc635ca378c85a Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 14:06:57 -0700 Subject: [PATCH 19/66] edit bash --- examples/benchmark_mcmc.sh | 38 +++------------------------------ examples/simple_trainer_mcmc.py | 9 +++----- 2 files changed, 6 insertions(+), 41 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index f065860a6..e66894a4f 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,7 +1,5 @@ -RESULT_DIR=results/2dgs_mcmc_sfm - # for SCENE in bicycle bonsai counter garden kitchen room stump; -for SCENE in garden treehill bonsai counter kitchen room bicycle stump flowers; +for SCENE in garden treehill; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then DATA_FACTOR=4 @@ -23,42 +21,12 @@ do echo "Running $SCENE" - # train without eval python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ --model_type 2dgs \ --init_type sfm \ + --eval_steps 1000 7000 15000 30000 \ --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ + --result_dir results/2dgs_mcmc_sfm/$SCENE/ - # # run eval and render - # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - # do - # python simple_trainer.py --disable_viewer --data_factor $DATA_FACTOR \ - # --data_dir data/360_v2/$SCENE/ \ - # --result_dir $RESULT_DIR/$SCENE/ \ - # --ckpt $CKPT - # done done - - -for SCENE in bicycle bonsai counter garden kitchen room stump; -do - echo "=== Eval Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/val*; - do - echo $STATS - cat $STATS; - echo - done - - echo "=== Train Stats ===" - - for STATS in $RESULT_DIR/$SCENE/stats/train*; - do - echo $STATS - cat $STATS; - echo - done -done \ No newline at end of file diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 453b8299e..eaa901e09 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -69,13 +69,9 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field( - default_factory=lambda: [1_000, 7_000, 15_000, 30_000] - ) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model - save_steps: List[int] = field( - default_factory=lambda: [1_000, 7_000, 15_000, 30_000] - ) + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Initialization strategy init_type: str = "sfm" @@ -184,6 +180,7 @@ class Runner: def __init__(self, cfg: Config) -> None: set_random_seed(42) + print(cfg.eval_steps) self.cfg = cfg self.device = "cuda" From 23f99849b723061efce63884e8a59aeedbfee513 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 14:54:41 -0700 Subject: [PATCH 20/66] clean up benchmark script --- examples/benchmark_mcmc.sh | 40 +++++++++++++++++++++++++++++---- examples/simple_trainer_mcmc.py | 29 +++++++++++------------- gsplat/cuda/csrc/utils.cuh | 3 +-- gsplat/rendering.py | 6 ++--- 4 files changed, 53 insertions(+), 25 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index e66894a4f..f5ccdbb65 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -20,13 +20,45 @@ do fi echo "Running $SCENE" + EVAL_STEPS="1000 7000 15000 30000" - python simple_trainer_mcmc.py --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 2dgs \ + # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + # --model_type 2dgs_inria \ + # --init_type sfm \ + # --cap_max $CAP_MAX \ + # --data_dir data/360_v2/$SCENE/ \ + # --result_dir results/2dgs_inria/$SCENE/ + + # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + # --model_type 3dgs_inria \ + # --init_type sfm \ + # --cap_max $CAP_MAX \ + # --data_dir data/360_v2/$SCENE/ \ + # --result_dir results/3dgs_inria/$SCENE/ + + # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + # --model_type 3dgs \ + # --init_type sfm \ + # --cap_max $CAP_MAX \ + # --data_dir data/360_v2/$SCENE/ \ + # --result_dir results/3dgs/$SCENE/ + + # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + # --model_type 2dgs_inria \ + # --init_type sfm \ + # --cap_max $CAP_MAX \ + # --data_dir data/360_v2/$SCENE/ \ + # --normal_consistency_loss \ + # --dist_loss \ + # --result_dir results/2dgs_inria_with_normal/$SCENE/ + + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 3dgs \ --init_type sfm \ - --eval_steps 1000 7000 15000 30000 \ --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ - --result_dir results/2dgs_mcmc_sfm/$SCENE/ + --normal_consistency_loss \ + --dist_loss \ + --result_dir results/3dgs_with_normal/$SCENE/ done diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index eaa901e09..8f3911fa4 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -38,7 +38,7 @@ @dataclass class Config: - # Model type can be 3dgs or 2dgs + # Model type can be 3dgs, 2dgs model_type: str = "3dgs" # Disable viewer disable_viewer: bool = False @@ -147,14 +147,14 @@ class Config: depth_lambda: float = 1e-2 # Enable normal consistency loss. (experimental) - normal_consistency_loss: bool = True + normal_consistency_loss: bool = False # Weight for normal consistency loss normal_consistency_lambda: float = 0.05 # Start applying normal consistency loss after this iteration normal_consistency_start_iter: int = 7000 # Distoration loss. (experimental) - dist_loss: bool = True + dist_loss: bool = False # Weight for distortion loss dist_lambda: float = 100 # Start applying distortion loss after this iteration @@ -180,7 +180,6 @@ class Runner: def __init__(self, cfg: Config) -> None: set_random_seed(42) - print(cfg.eval_steps) self.cfg = cfg self.device = "cuda" @@ -188,12 +187,12 @@ def __init__(self, cfg: Config) -> None: self.rasterization_fn = rasterization elif cfg.model_type == "3dgs_inria": self.rasterization_fn = rasterization_inria_wrapper - elif cfg.model_type == "2dgs": + elif cfg.model_type == "2dgs_inria": self.rasterization_fn = rasterization_2dgs_inria_wrapper else: raise ValueError(f"Unsupported model type: {cfg.model_type}") - self.render_mode = "RGB" + self.render_mode = "RGB+ED" if cfg.depth_loss or cfg.normal_consistency_loss: self.render_mode = "RGB+ED" @@ -750,14 +749,13 @@ def eval(self, step: int): ellipse_time += time.time() - tic colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) - depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None - normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None - canvas_list = [pixels, colors] - if self.cfg.depth_loss: + if renders.shape[-1] >= 4: + depths = renders[..., 3:4] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) - if self.cfg.normal_consistency_loss: + if renders.shape[-1] >= 5: + normals = renders[..., 4:7] normals_surf = depth_to_normal( depths, camtoworlds, @@ -853,14 +851,13 @@ def render_traj(self, step: int): ) # [1, H, W, C] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) - depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None - normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None - canvas_list = [colors] - if self.cfg.depth_loss: + if renders.shape[-1] >= 4: + depths = renders[..., 3:4] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) - if self.cfg.normal_consistency_loss: + if renders.shape[-1] >= 5: + normals = renders[..., 4:7] normals_surf = depth_to_normal( depths, camtoworlds, diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 807d506ec..2b772e54d 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -134,7 +134,7 @@ inline __device__ void quat_scale_to_covar_vjp_normal( mat3 v_R = v_M * S; // add contribution from v_normal - // printf("v_normal: %.8f, %.8f, %.8f \n", v_normal.x, v_normal.y, v_normal.z); + // someone should double check this. No idea if this is correct v_R[2] += v_normal; // grad for (quat, scale) from covar @@ -383,5 +383,4 @@ inline __device__ void add_blur_vjp(const T eps2d, const mat2 conic_blur, v_sqr_comp * (one_minus_sqr_comp * conic_blur[1][1] - eps2d * det_conic_blur); } - #endif // GSPLAT_CUDA_UTILS_H \ No newline at end of file diff --git a/gsplat/rendering.py b/gsplat/rendering.py index ecc4a2ece..10066bb15 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -555,7 +555,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): device = means.device channels = colors.shape[-1] - # rasterization from inria does not normalization internally + # rasterization from inria does not do normalization internally quats = F.normalize(quats, dim=-1) # [N, 4] render_colors = [] @@ -607,7 +607,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): _colors.shape[0], 3 - _colors.shape[-1], device=device ) _colors = torch.cat([_colors, pad], dim=-1) - _render_colors_, _ = rasterizer( + _render_colors_, radii = rasterizer( means3D=means, means2D=means2D, shs=_colors if colors.dim() == 3 else None, @@ -683,7 +683,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): device = means.device channels = colors.shape[-1] - # rasterization from inria does not normalization internally + # rasterization from inria does not do normalization internally quats = F.normalize(quats, dim=-1) # [N, 4] scales = scales[:, :2] # [N, 2] From 02e00fbe45ca3e39c7810c1716c12edd5af9eab6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 14:58:46 -0700 Subject: [PATCH 21/66] remove dist_loss --- examples/benchmark_mcmc.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index f5ccdbb65..fddadf028 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -1,4 +1,4 @@ -# for SCENE in bicycle bonsai counter garden kitchen room stump; +# for SCENE in bicycle bonsai counter garden kitchen room stump treehill flowers; for SCENE in garden treehill; do if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then @@ -58,7 +58,6 @@ do --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ --normal_consistency_loss \ - --dist_loss \ --result_dir results/3dgs_with_normal/$SCENE/ done From 96e8d2384bfccaf32c20143bd6c53643afd39806 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 16:16:21 -0700 Subject: [PATCH 22/66] depth must be last --- examples/simple_trainer_mcmc.py | 17 ++++++++++------- gsplat/rendering.py | 4 ++-- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 8f3911fa4..44400d3e8 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -428,9 +428,7 @@ def train(self): image_ids=image_ids, render_mode=self.render_mode, ) - colors = renders[..., 0:3] - depths = renders[..., 3:4] if renders.shape[-1] >= 4 else None - normals = renders[..., 4:7] if renders.shape[-1] >= 5 else None + colors = renders[..., :3] if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) @@ -449,6 +447,7 @@ def train(self): loss += cfg.scale_reg * torch.abs(torch.exp(self.splats["scales"])).mean() if cfg.depth_loss: + depths = renders[..., -1:] # query depths from depth map points = torch.stack( [ @@ -469,6 +468,8 @@ def train(self): loss += depthloss * cfg.depth_lambda if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: + depths = renders[..., -1:] + normals = renders[..., -4:-1] normals_surf = depth_to_normal( depths, camtoworlds, @@ -751,11 +752,12 @@ def eval(self, step: int): colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [pixels, colors] if renders.shape[-1] >= 4: - depths = renders[..., 3:4] + depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) if renders.shape[-1] >= 5: - normals = renders[..., 4:7] + depths = renders[..., -1:] + normals = renders[..., -4:-1] normals_surf = depth_to_normal( depths, camtoworlds, @@ -853,11 +855,12 @@ def render_traj(self, step: int): colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [colors] if renders.shape[-1] >= 4: - depths = renders[..., 3:4] + depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) if renders.shape[-1] >= 5: - normals = renders[..., 4:7] + depths = renders[..., -1:] + normals = renders[..., -4:-1] normals_surf = depth_to_normal( depths, camtoworlds, diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 10066bb15..a4b749271 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -377,8 +377,8 @@ def rasterization( render_colors = torch.cat( [ render_colors[..., :3], - render_colors[..., 3:4] / render_alphas.clamp(min=1e-10), - F.normalize(render_colors[..., 4:7], dim=-1), + F.normalize(render_colors[..., -4:-1], dim=-1), + render_colors[..., -1:] / render_alphas.clamp(min=1e-10), ], dim=-1, ) From f336b9501c3d3064ac837e3b4201b20e2b3bcc5f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 5 Jul 2024 16:29:39 -0700 Subject: [PATCH 23/66] colors normal depth --- examples/simple_trainer_mcmc.py | 12 ++++++------ gsplat/rendering.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 44400d3e8..e082b43c4 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -751,11 +751,11 @@ def eval(self, step: int): colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [pixels, colors] - if renders.shape[-1] >= 4: + if cfg.depth_loss: depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) - if renders.shape[-1] >= 5: + if cfg.normal_consistency_loss: depths = renders[..., -1:] normals = renders[..., -4:-1] normals_surf = depth_to_normal( @@ -768,7 +768,7 @@ def eval(self, step: int): normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) - if self.cfg.dist_loss: + if cfg.dist_loss: distloss = info["render_distloss"] distloss = (distloss - distloss.min()) / ( distloss.max() - distloss.min() @@ -854,11 +854,11 @@ def render_traj(self, step: int): colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [colors] - if renders.shape[-1] >= 4: + if cfg.depth_loss: depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) - if renders.shape[-1] >= 5: + if cfg.normal_consistency_loss: depths = renders[..., -1:] normals = renders[..., -4:-1] normals_surf = depth_to_normal( @@ -871,7 +871,7 @@ def render_traj(self, step: int): normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) - if self.cfg.dist_loss: + if cfg.dist_loss: distloss = info["render_distloss"] distloss = (distloss - distloss.min()) / ( distloss.max() - distloss.min() diff --git a/gsplat/rendering.py b/gsplat/rendering.py index a4b749271..dcba0e852 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -317,7 +317,7 @@ def rasterization( # Rasterize to pixels if render_mode in ["RGB+D", "RGB+ED"]: - colors = torch.cat((colors, depths[..., None], normals), dim=-1) + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) if backgrounds is not None: backgrounds = torch.cat( [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 @@ -777,5 +777,5 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): ) meta = {"render_distloss": render_dist} - render_colors = torch.cat([render_colors, render_depth, render_normal], dim=-1) + render_colors = torch.cat([render_colors, render_normal, render_depth], dim=-1) return render_colors, render_alphas, meta From 54f137e973faf3e54786c605777de8fbddac0ea6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 09:45:57 -0700 Subject: [PATCH 24/66] reduce diff --- examples/simple_trainer_mcmc.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index e082b43c4..3dd591348 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -467,7 +467,7 @@ def train(self): depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda - if cfg.normal_consistency_loss and step > cfg.normal_consistency_start_iter: + if cfg.normal_consistency_loss: depths = renders[..., -1:] normals = renders[..., -4:-1] normals_surf = depth_to_normal( @@ -481,11 +481,13 @@ def train(self): normalconsistencyloss = ( 1 - (normals * normals_surf).sum(dim=-1) ).mean() - loss += normalconsistencyloss * cfg.normal_consistency_lambda + if step > cfg.normal_consistency_start_iter: + loss += normalconsistencyloss * cfg.normal_consistency_lambda - if cfg.dist_loss and step > cfg.dist_start_iter: + if cfg.dist_loss: distloss = info["render_distloss"].mean() - loss += distloss * cfg.dist_lambda + if step > cfg.dist_start_iter: + loss += distloss * cfg.dist_lambda loss.backward() @@ -509,16 +511,13 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) - if ( - cfg.normal_consistency_loss - and step > cfg.normal_consistency_start_iter - ): + if cfg.normal_consistency_loss: self.writer.add_scalar( "train/normalconsistencyloss", normalconsistencyloss.item(), step, ) - if cfg.dist_loss and step > cfg.dist_start_iter: + if cfg.dist_loss: self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() From 50749debfe5001647dd8dfbe3b38016e5fcfdcc6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 09:50:38 -0700 Subject: [PATCH 25/66] cleanup --- examples/simple_trainer_mcmc.py | 2 +- gsplat/rendering.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 3dd591348..fbec6dc8e 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -192,7 +192,7 @@ def __init__(self, cfg: Config) -> None: else: raise ValueError(f"Unsupported model type: {cfg.model_type}") - self.render_mode = "RGB+ED" + self.render_mode = "RGB" if cfg.depth_loss or cfg.normal_consistency_loss: self.render_mode = "RGB+ED" diff --git a/gsplat/rendering.py b/gsplat/rendering.py index dcba0e852..40d548dc3 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -377,7 +377,7 @@ def rasterization( render_colors = torch.cat( [ render_colors[..., :3], - F.normalize(render_colors[..., -4:-1], dim=-1), + render_colors[..., -4:-1], render_colors[..., -1:] / render_alphas.clamp(min=1e-10), ], dim=-1, From bcac80416a949e13d06a99684def89c36de92e9c Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 09:56:02 -0700 Subject: [PATCH 26/66] remove distloss --- examples/simple_trainer_mcmc.py | 35 ++------------------------------- gsplat/color_utils.py | 33 ------------------------------- 2 files changed, 2 insertions(+), 66 deletions(-) delete mode 100644 gsplat/color_utils.py diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index fbec6dc8e..0f786653c 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -32,7 +32,6 @@ ) from gsplat.relocation import compute_relocation from gsplat.normal_utils import depth_to_normal -from gsplat.color_utils import apply_float_colormap from simple_trainer import create_splats_with_optimizers @@ -153,13 +152,6 @@ class Config: # Start applying normal consistency loss after this iteration normal_consistency_start_iter: int = 7000 - # Distoration loss. (experimental) - dist_loss: bool = False - # Weight for distortion loss - dist_lambda: float = 100 - # Start applying distortion loss after this iteration - dist_start_iter: int = 3000 - # Dump information to tensorboard every this steps tb_every: int = 100 # Save training images to tensorboard @@ -484,11 +476,6 @@ def train(self): if step > cfg.normal_consistency_start_iter: loss += normalconsistencyloss * cfg.normal_consistency_lambda - if cfg.dist_loss: - distloss = info["render_distloss"].mean() - if step > cfg.dist_start_iter: - loss += distloss * cfg.dist_lambda - loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " @@ -517,8 +504,6 @@ def train(self): normalconsistencyloss.item(), step, ) - if cfg.dist_loss: - self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) @@ -753,7 +738,7 @@ def eval(self, step: int): if cfg.depth_loss: depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) - canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) + canvas_list.append(depths) if cfg.normal_consistency_loss: depths = renders[..., -1:] normals = renders[..., -4:-1] @@ -767,14 +752,6 @@ def eval(self, step: int): normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) - if cfg.dist_loss: - distloss = info["render_distloss"] - distloss = (distloss - distloss.min()) / ( - distloss.max() - distloss.min() - ) - canvas_list.append( - apply_float_colormap(distloss, colormap="cmr.voltage") - ) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() @@ -856,7 +833,7 @@ def render_traj(self, step: int): if cfg.depth_loss: depths = renders[..., -1:] depths = (depths - depths.min()) / (depths.max() - depths.min()) - canvas_list.append(apply_float_colormap(1 - depths, colormap="turbo")) + canvas_list.append(depths) if cfg.normal_consistency_loss: depths = renders[..., -1:] normals = renders[..., -4:-1] @@ -870,14 +847,6 @@ def render_traj(self, step: int): normals_surf = normals_surf * (alphas).detach() canvas_list.extend([normals * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) - if cfg.dist_loss: - distloss = info["render_distloss"] - distloss = (distloss - distloss.min()) / ( - distloss.max() - distloss.min() - ) - canvas_list.append( - apply_float_colormap(distloss, colormap="cmr.voltage") - ) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/color_utils.py b/gsplat/color_utils.py deleted file mode 100644 index 69d9bedf2..000000000 --- a/gsplat/color_utils.py +++ /dev/null @@ -1,33 +0,0 @@ -import matplotlib -from jaxtyping import Float -import torch -from torch import Tensor -import cmasher - - -def apply_float_colormap( - image: Float[Tensor, "*bs 1"], colormap: str = "viridis" -) -> Float[Tensor, "*bs rgb=3"]: - """Convert single channel to a color image. - - Args: - image: Single channel image. - colormap: Colormap for image. - - Returns: - Tensor: Colored image with colors in [0, 1] - """ - if colormap == "default": - colormap = "turbo" - - image = torch.nan_to_num(image, 0) - if colormap == "gray": - return image.repeat(1, 1, 3) - image_long = (image * 255).long() - image_long_min = torch.min(image_long) - image_long_max = torch.max(image_long) - assert image_long_min >= 0, f"the min value is {image_long_min}" - assert image_long_max <= 255, f"the max value is {image_long_max}" - return torch.tensor(matplotlib.colormaps[colormap].colors, device=image.device)[ - image_long[..., 0] - ] From 57522617e1f3a3b67c8099143c6a61ee7b57ddf5 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 09:58:04 -0700 Subject: [PATCH 27/66] benchmark script --- examples/benchmark_mcmc.sh | 57 +++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh index fddadf028..3d2e293e3 100644 --- a/examples/benchmark_mcmc.sh +++ b/examples/benchmark_mcmc.sh @@ -22,42 +22,41 @@ do echo "Running $SCENE" EVAL_STEPS="1000 7000 15000 30000" - # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - # --model_type 2dgs_inria \ - # --init_type sfm \ - # --cap_max $CAP_MAX \ - # --data_dir data/360_v2/$SCENE/ \ - # --result_dir results/2dgs_inria/$SCENE/ + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 3dgs \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --data_dir data/360_v2/$SCENE/ \ + --normal_consistency_loss \ + --result_dir results/3dgs_with_normal/$SCENE/ - # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - # --model_type 3dgs_inria \ - # --init_type sfm \ - # --cap_max $CAP_MAX \ - # --data_dir data/360_v2/$SCENE/ \ - # --result_dir results/3dgs_inria/$SCENE/ + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 2dgs_inria \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --data_dir data/360_v2/$SCENE/ \ + --normal_consistency_loss \ + --result_dir results/2dgs_inria_with_normal/$SCENE/ - # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - # --model_type 3dgs \ - # --init_type sfm \ - # --cap_max $CAP_MAX \ - # --data_dir data/360_v2/$SCENE/ \ - # --result_dir results/3dgs/$SCENE/ + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 3dgs \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir results/3dgs/$SCENE/ - # python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - # --model_type 2dgs_inria \ - # --init_type sfm \ - # --cap_max $CAP_MAX \ - # --data_dir data/360_v2/$SCENE/ \ - # --normal_consistency_loss \ - # --dist_loss \ - # --result_dir results/2dgs_inria_with_normal/$SCENE/ + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 2dgs_inria \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir results/2dgs_inria/$SCENE/ python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 3dgs \ + --model_type 3dgs_inria \ --init_type sfm \ --cap_max $CAP_MAX \ --data_dir data/360_v2/$SCENE/ \ - --normal_consistency_loss \ - --result_dir results/3dgs_with_normal/$SCENE/ + --result_dir results/3dgs_inria/$SCENE/ done From dadaf759edca0e5de877f5adaabdc5c6f8ca6ecb Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 12:26:34 -0700 Subject: [PATCH 28/66] refactor --- examples/simple_trainer_mcmc.py | 2 +- .../cuda/csrc/fully_fused_projection_bwd.cu | 8 +++- .../cuda/csrc/fully_fused_projection_fwd.cu | 16 ++++---- gsplat/cuda/csrc/utils.cuh | 39 ------------------- 4 files changed, 16 insertions(+), 49 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 0f786653c..a1277baf4 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -37,7 +37,7 @@ @dataclass class Config: - # Model type can be 3dgs, 2dgs + # Model type can be 3dgs, 3dgs_inria, or 2dgs_inria model_type: str = "3dgs" # Disable viewer disable_viewer: bool = False diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 928367106..ddcc50568 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -153,8 +153,12 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - // quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - quat_scale_to_covar_vjp_normal(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals. Please double check this is correct. + mat3 v_R = quat_to_rotmat(quat); + v_R[2] += v_normals; + quat_to_rotmat_vjp(quat, v_R, v_quat); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index cb93fb198..548468e81 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -47,8 +47,6 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, means += gid * 3; viewmats += cid * 16; Ks += cid * 9; - quats += gid * 4; - scales += gid * 3; // glm is column-major but input is row-major mat3 R = mat3(viewmats[0], viewmats[4], viewmats[8], // 1st column @@ -67,6 +65,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, // transform Gaussian covariance to camera space mat3 covar; + vec3 normal; if (covars != nullptr) { covars += gid * 6; covar = mat3(covars[0], covars[1], covars[2], // 1st column @@ -74,8 +73,13 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, covars[2], covars[4], covars[5] // 3rd column ); } else { + quats += gid * 4; + scales += gid * 3; // compute from quaternions and scales quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); + + glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + normal = rotmat[2]; } mat3 covar_c; covar_world_to_cam(R, covar, covar_c); @@ -116,16 +120,14 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, return; } - glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); - // write to outputs radii[idx] = (int32_t)radius; means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; depths[idx] = mean_c.z; - normals[idx * 3] = rotmat[2].x; - normals[idx * 3 + 1] = rotmat[2].y; - normals[idx * 3 + 2] = rotmat[2].z; + normals[idx * 3] = normal.x; + normals[idx * 3 + 1] = normal.y; + normals[idx * 3 + 2] = normal.z; conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; conics[idx * 3 + 2] = covar2d_inv[1][1]; diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 2b772e54d..50b76df1c 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -106,45 +106,6 @@ inline __device__ void quat_scale_to_covar_vjp( v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; } -template -inline __device__ void quat_scale_to_covar_vjp_normal( - // fwd inputs - const vec4 quat, const vec3 scale, - // precompute - const mat3 R, - // grad outputs - const mat3 v_covar, - const vec3 v_normal, - // grad inputs - vec4 &v_quat, vec3 &v_scale) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - T sx = scale[0], sy = scale[1], sz = scale[2]; - - // M = R * S - mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); - mat3 M = R * S; - - // https://math.stackexchange.com/a/3850121 - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - // so - // for D = M * Mt, - // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M - mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; - mat3 v_R = v_M * S; - - // add contribution from v_normal - // someone should double check this. No idea if this is correct - v_R[2] += v_normal; - - // grad for (quat, scale) from covar - quat_to_rotmat_vjp(quat, v_R, v_quat); - - v_scale[0] += R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; - v_scale[1] += R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; - v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; -} - template inline __device__ void quat_scale_to_preci_vjp( From 6800602105756f20229c65af26af1910078fa20a Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 12:33:19 -0700 Subject: [PATCH 29/66] compile bug --- gsplat/cuda/csrc/fully_fused_projection_bwd.cu | 4 ++-- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 2 +- gsplat/cuda/csrc/utils.cuh | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index ddcc50568..a437e8918 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -155,9 +155,9 @@ __global__ void fully_fused_projection_bwd_kernel( vec3 v_scale(0.f); quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - // add contribution from v_normals. Please double check this is correct. + // add contribution from v_normals. Please check if this is correct. mat3 v_R = quat_to_rotmat(quat); - v_R[2] += v_normals; + v_R[2] += glm::make_vec3(v_normals); quat_to_rotmat_vjp(quat, v_R, v_quat); warpSum(v_quat, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 548468e81..c491b1b6f 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -73,9 +73,9 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, covars[2], covars[4], covars[5] // 3rd column ); } else { + // compute from quaternions and scales quats += gid * 4; scales += gid * 3; - // compute from quaternions and scales quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 50b76df1c..57751315f 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -106,7 +106,6 @@ inline __device__ void quat_scale_to_covar_vjp( v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; } - template inline __device__ void quat_scale_to_preci_vjp( // fwd inputs From 6d156d2d447d9557e432b7f4b94fa6b3b9527187 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 8 Jul 2024 17:02:33 -0700 Subject: [PATCH 30/66] remove benchmark script --- examples/benchmark_mcmc.sh | 62 -------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 examples/benchmark_mcmc.sh diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh deleted file mode 100644 index 3d2e293e3..000000000 --- a/examples/benchmark_mcmc.sh +++ /dev/null @@ -1,62 +0,0 @@ -# for SCENE in bicycle bonsai counter garden kitchen room stump treehill flowers; -for SCENE in garden treehill; -do - if [ "$SCENE" = "bicycle" ] || [ "$SCENE" = "stump" ] || [ "$SCENE" = "garden" ] || [ "$SCENE" = "treehill" ] || [ "$SCENE" = "flowers" ]; then - DATA_FACTOR=4 - else - DATA_FACTOR=2 - fi - - if [ "$SCENE" = "bonsai" ]; then - CAP_MAX=1300000 - elif [ "$SCENE" = "counter" ]; then - CAP_MAX=1200000 - elif [ "$SCENE" = "kitchen" ]; then - CAP_MAX=1800000 - elif [ "$SCENE" = "room" ]; then - CAP_MAX=1500000 - else - CAP_MAX=3000000 - fi - - echo "Running $SCENE" - EVAL_STEPS="1000 7000 15000 30000" - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 3dgs \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --data_dir data/360_v2/$SCENE/ \ - --normal_consistency_loss \ - --result_dir results/3dgs_with_normal/$SCENE/ - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 2dgs_inria \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --data_dir data/360_v2/$SCENE/ \ - --normal_consistency_loss \ - --result_dir results/2dgs_inria_with_normal/$SCENE/ - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 3dgs \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --data_dir data/360_v2/$SCENE/ \ - --result_dir results/3dgs/$SCENE/ - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 2dgs_inria \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --data_dir data/360_v2/$SCENE/ \ - --result_dir results/2dgs_inria/$SCENE/ - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --model_type 3dgs_inria \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --data_dir data/360_v2/$SCENE/ \ - --result_dir results/3dgs_inria/$SCENE/ - -done From 9c06a245c1dab693b5815d79f727ddaacbbabaa5 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 11:44:58 -0700 Subject: [PATCH 31/66] support packed and sparse --- gsplat/cuda/_wrapper.py | 14 +++++++++++++- gsplat/cuda/csrc/bindings.h | 3 ++- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 2 +- .../csrc/fully_fused_projection_packed_bwd.cu | 17 ++++++++++++++++- .../csrc/fully_fused_projection_packed_fwd.cu | 17 +++++++++++++---- gsplat/rendering.py | 1 + 6 files changed, 46 insertions(+), 8 deletions(-) diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 1dbadf9c9..b3a31f02f 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -958,6 +958,7 @@ def forward( radii, means2d, depths, + normals, conics, compensations, ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")( @@ -994,7 +995,16 @@ def forward( ctx.eps2d = eps2d ctx.sparse_grad = sparse_grad - return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations + return ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) @staticmethod def backward( @@ -1004,6 +1014,7 @@ def backward( v_radii, v_means2d, v_depths, + v_normals, v_conics, v_compensations, ): @@ -1044,6 +1055,7 @@ def backward( compensations, v_means2d.contiguous(), v_depths.contiguous(), + v_normals.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index bde378773..4350d9810 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -180,7 +180,7 @@ compute_sh_bwd_tensor(const uint32_t K, const uint32_t degrees_to_use, * Packed Version ****************************************************************************************/ std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -210,6 +210,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, const bool sparse_grad); diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c491b1b6f..8a44269f7 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -31,7 +31,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, int32_t *__restrict__ radii, // [C, N] T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] - T *__restrict__ normals, // [C, N, 3] + T *__restrict__ normals, // [C, N, 3] T *__restrict__ conics, // [C, N, 3] T *__restrict__ compensations // [C, N] optional ) { diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index 476390814..661228c39 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -34,6 +34,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( // grad outputs const T *__restrict__ v_means2d, // [nnz, 2] const T *__restrict__ v_depths, // [nnz] + const T *__restrict__ v_normals, // [nnz, 3] const T *__restrict__ v_conics, // [nnz, 3] const T *__restrict__ v_compensations, // [nnz] optional const bool sparse_grad, // whether the outputs are in COO format [nnz, ...] @@ -61,6 +62,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_means2d += idx * 2; v_depths += idx; + v_normals += idx * 3; v_conics += idx * 3; // vjp: compute the inverse of the 2d covariance @@ -154,6 +156,11 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_scales[0] = v_scale[0]; v_scales[1] = v_scale[1]; v_scales[2] = v_scale[2]; + + // add contribution from v_normals. Please check if this is correct. + mat3 v_R = quat_to_rotmat(quat); + v_R[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_R, v_quat); } } else { // write out results with dense layout @@ -188,6 +195,12 @@ __global__ void fully_fused_projection_packed_bwd_kernel( vec4 v_quat(0.f); vec3 v_scale(0.f); quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals. Please check if this is correct. + mat3 v_R = quat_to_rotmat(quat); + v_R[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_R, v_quat); + warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); if (warp_group_g.thread_rank() == 0) { @@ -240,6 +253,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, const bool sparse_grad) { @@ -259,6 +273,7 @@ fully_fused_projection_packed_bwd_tensor( CHECK_INPUT(conics); CHECK_INPUT(v_means2d); CHECK_INPUT(v_depths); + CHECK_INPUT(v_normals); CHECK_INPUT(v_conics); if (compensations.has_value()) { CHECK_INPUT(compensations.value()); @@ -309,7 +324,7 @@ fully_fused_projection_packed_bwd_tensor( compensations.has_value() ? compensations.value().data_ptr() : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), - v_conics.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() : nullptr, sparse_grad, v_means.data_ptr(), diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 7f7082f6e..719895e38 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -36,6 +36,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( int32_t *__restrict__ radii, // [nnz] T *__restrict__ means2d, // [nnz, 2] T *__restrict__ depths, // [nnz] + T *__restrict__ normals, // [nnz, 3] T *__restrict__ conics, // [nnz, 3] T *__restrict__ compensations // [nnz] optional ) { @@ -75,6 +76,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( mat2 covar2d; vec2 mean2d; mat2 covar2d_inv; + vec3 normal; T compensation; T det; if (valid) { @@ -92,6 +94,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel( quats += col_idx * 4; scales += col_idx * 3; quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); + + glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + normal = rotmat[2]; } mat3 covar_c; covar_world_to_cam(R, covar, covar_c); @@ -163,6 +168,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel( means2d[thread_data * 2] = mean2d.x; means2d[thread_data * 2 + 1] = mean2d.y; depths[thread_data] = mean_c.z; + normals[thread_data * 3] = normal.x; + normals[thread_data * 3 + 1] = normal.y; + normals[thread_data * 3 + 2] = normal.z; conics[thread_data * 3] = covar2d_inv[0][0]; conics[thread_data * 3 + 1] = covar2d_inv[0][1]; conics[thread_data * 3 + 2] = covar2d_inv[1][1]; @@ -183,7 +191,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( } std::tuple + torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] const at::optional &covars, // [N, 6] @@ -232,7 +240,7 @@ fully_fused_projection_packed_fwd_tensor( viewmats.data_ptr(), Ks.data_ptr(), image_width, image_height, eps2d, near_plane, far_plane, radius_clip, nullptr, block_cnts.data_ptr(), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr); + nullptr, nullptr, nullptr, nullptr); block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); nnz = block_accum[-1].item(); } else { @@ -246,6 +254,7 @@ fully_fused_projection_packed_fwd_tensor( torch::Tensor radii = torch::empty({nnz}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({nnz, 2}, means.options()); torch::Tensor depths = torch::empty({nnz}, means.options()); + torch::Tensor normals = torch::empty({nnz, 3}, means.options()); torch::Tensor conics = torch::empty({nnz, 3}, means.options()); torch::Tensor compensations; if (calc_compensations) { @@ -264,12 +273,12 @@ fully_fused_projection_packed_fwd_tensor( nullptr, indptr.data_ptr(), camera_ids.data_ptr(), gaussian_ids.data_ptr(), radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), - conics.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr); } else { indptr.fill_(0); } return std::make_tuple(indptr, camera_ids, gaussian_ids, radii, means2d, depths, - conics, compensations); + normals, conics, compensations); } diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 40d548dc3..3ef25bf6e 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -242,6 +242,7 @@ def rasterization( radii, means2d, depths, + normals, conics, compensations, ) = proj_results From 7294be4b58b1eb57fe173e2ba29e239464ac5b68 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 12:33:23 -0700 Subject: [PATCH 32/66] refactor --- examples/simple_trainer_mcmc.py | 51 +++++------------- gsplat/normal_utils.py | 3 +- gsplat/rendering.py | 92 +++++++++++++++++++++++---------- 3 files changed, 81 insertions(+), 65 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index a1277baf4..b0ae95195 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -31,7 +31,6 @@ rasterization_inria_wrapper, ) from gsplat.relocation import compute_relocation -from gsplat.normal_utils import depth_to_normal from simple_trainer import create_splats_with_optimizers @@ -185,8 +184,10 @@ def __init__(self, cfg: Config) -> None: raise ValueError(f"Unsupported model type: {cfg.model_type}") self.render_mode = "RGB" - if cfg.depth_loss or cfg.normal_consistency_loss: + if cfg.depth_loss: self.render_mode = "RGB+ED" + if cfg.normal_consistency_loss: + self.render_mode = "RGB+ED+N" # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -460,18 +461,10 @@ def train(self): loss += depthloss * cfg.depth_lambda if cfg.normal_consistency_loss: - depths = renders[..., -1:] - normals = renders[..., -4:-1] - normals_surf = depth_to_normal( - depths, - camtoworlds, - Ks, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - ) - normals_surf = normals_surf * (alphas).detach() + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] normalconsistencyloss = ( - 1 - (normals * normals_surf).sum(dim=-1) + 1 - (normals_rend * normals_surf).sum(dim=-1) ).mean() if step > cfg.normal_consistency_start_iter: loss += normalconsistencyloss * cfg.normal_consistency_lambda @@ -729,7 +722,7 @@ def eval(self, step: int): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode=self.render_mode, - ) # [1, H, W, C] + ) # [1, H, W, K] torch.cuda.synchronize() ellipse_time += time.time() - tic @@ -740,17 +733,9 @@ def eval(self, step: int): depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths) if cfg.normal_consistency_loss: - depths = renders[..., -1:] - normals = renders[..., -4:-1] - normals_surf = depth_to_normal( - depths, - camtoworlds, - Ks, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - ) - normals_surf = normals_surf * (alphas).detach() - canvas_list.extend([normals * 0.5 + 0.5]) + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) # write images @@ -826,7 +811,7 @@ def render_traj(self, step: int): near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode=self.render_mode, - ) # [1, H, W, C] + ) # [1, H, W, K] colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [colors] @@ -835,17 +820,9 @@ def render_traj(self, step: int): depths = (depths - depths.min()) / (depths.max() - depths.min()) canvas_list.append(depths) if cfg.normal_consistency_loss: - depths = renders[..., -1:] - normals = renders[..., -4:-1] - normals_surf = depth_to_normal( - depths, - camtoworlds, - Ks, - near_plane=cfg.near_plane, - far_plane=cfg.far_plane, - ) - normals_surf = normals_surf * (alphas).detach() - canvas_list.extend([normals * 0.5 + 0.5]) + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) # write images diff --git a/gsplat/normal_utils.py b/gsplat/normal_utils.py index d8c30ed92..5124201f1 100644 --- a/gsplat/normal_utils.py +++ b/gsplat/normal_utils.py @@ -43,9 +43,8 @@ def _depth_to_normal(depth, world_view_transform, full_proj_transform): return output -def depth_to_normal(depths, camtoworlds, Ks, near_plane, far_plane): +def depth_to_normal(depths, viewmats, Ks, near_plane, far_plane): height, width = depths.shape[1:3] - viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] normals = [] for cid, depth in enumerate(depths): diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 3ef25bf6e..e2204f352 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -13,6 +13,7 @@ rasterize_to_pixels, spherical_harmonics, ) +from .normal_utils import depth_to_normal def rasterization( @@ -33,7 +34,7 @@ def rasterization( packed: bool = True, tile_size: int = 16, backgrounds: Optional[Tensor] = None, - render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"] = "RGB", sparse_grad: bool = False, absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", @@ -198,7 +199,7 @@ def rasterization( assert opacities.shape == (N,), opacities.shape assert viewmats.shape == (C, 4, 4), viewmats.shape assert Ks.shape == (C, 3, 3), Ks.shape - assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"], render_mode if sh_degree is None: # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] @@ -318,7 +319,7 @@ def rasterization( # Rasterize to pixels if render_mode in ["RGB+D", "RGB+ED"]: - colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + colors = torch.cat((colors, depths[..., None]), dim=-1) if backgrounds is not None: backgrounds = torch.cat( [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 @@ -327,6 +328,12 @@ def rasterization( colors = depths[..., None] if backgrounds is not None: backgrounds = torch.zeros(C, 1, device=backgrounds.device) + elif render_mode in ["RGB+ED+N"]: + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + if backgrounds is not None: + backgrounds = torch.cat( + [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + ) else: # RGB pass if colors.shape[-1] > channel_chunk: @@ -373,35 +380,55 @@ def rasterization( packed=packed, absgrad=absgrad, ) - if render_mode in ["ED", "RGB+ED"]: + + meta = {} + if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: # normalize the accumulated depth to get the expected depth + depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) render_colors = torch.cat( [ render_colors[..., :3], - render_colors[..., -4:-1], - render_colors[..., -1:] / render_alphas.clamp(min=1e-10), + depths_expected, ], dim=-1, ) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + depths_expected, + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } + ) - meta = { - "camera_ids": camera_ids, - "gaussian_ids": gaussian_ids, - "radii": radii, - "means2d": means2d, - "depths": depths, - "conics": conics, - "opacities": opacities, - "tile_width": tile_width, - "tile_height": tile_height, - "tiles_per_gauss": tiles_per_gauss, - "isect_ids": isect_ids, - "flatten_ids": flatten_ids, - "isect_offsets": isect_offsets, - "width": width, - "height": height, - "tile_size": tile_size, - } + meta.update( + { + "camera_ids": camera_ids, + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "conics": conics, + "opacities": opacities, + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + } + ) return render_colors, render_alphas, meta @@ -777,6 +804,19 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - meta = {"render_distloss": render_dist} - render_colors = torch.cat([render_colors, render_normal, render_depth], dim=-1) + normals_surf = depth_to_normal( + render_depth, + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + + meta = { + "normals_rend": render_normal, + "normals_surf": normals_surf, + "render_distloss": render_dist, + } + render_colors = torch.cat([render_colors, render_depth], dim=-1) return render_colors, render_alphas, meta From 0a9ce04e491d1e8ee4e716917332becfef0215a1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 12:51:07 -0700 Subject: [PATCH 33/66] bugfix --- gsplat/rendering.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index e2204f352..f56a2100e 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -385,13 +385,6 @@ def rasterization( if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: # normalize the accumulated depth to get the expected depth depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) - render_colors = torch.cat( - [ - render_colors[..., :3], - depths_expected, - ], - dim=-1, - ) if render_mode in ["RGB+ED+N"]: normals_rend = render_colors[..., -4:-1] normals_surf = depth_to_normal( @@ -409,6 +402,14 @@ def rasterization( } ) + render_colors = torch.cat( + [ + render_colors[..., :3], + depths_expected, + ], + dim=-1, + ) + meta.update( { "camera_ids": camera_ids, From 1bd9e9170c23d9921d141b4a6cd4e09d9d38c2c7 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 13:47:34 -0700 Subject: [PATCH 34/66] utils/ --- gsplat/rendering.py | 129 +++++++++-------------------- gsplat/utils/__init__.py | 0 gsplat/utils/camera_utils.py | 25 ++++++ gsplat/{ => utils}/normal_utils.py | 29 +------ 4 files changed, 65 insertions(+), 118 deletions(-) create mode 100644 gsplat/utils/__init__.py create mode 100644 gsplat/utils/camera_utils.py rename gsplat/{ => utils}/normal_utils.py (74%) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index f56a2100e..bb989786f 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -13,7 +13,8 @@ rasterize_to_pixels, spherical_harmonics, ) -from .normal_utils import depth_to_normal +from .utils.normal_utils import depth_to_normal +from .utils.camera_utils import _getProjectionMatrix def rasterization( @@ -381,55 +382,43 @@ def rasterization( absgrad=absgrad, ) - meta = {} + meta = { + "camera_ids": camera_ids, + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "conics": conics, + "opacities": opacities, + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + } if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: # normalize the accumulated depth to get the expected depth - depths_expected = render_colors[..., -1:] / render_alphas.clamp(min=1e-10) - if render_mode in ["RGB+ED+N"]: - normals_rend = render_colors[..., -4:-1] - normals_surf = depth_to_normal( - depths_expected, - viewmats, - Ks, - near_plane=near_plane, - far_plane=far_plane, - ) - normals_surf = normals_surf * (render_alphas).detach() - meta.update( - { - "normals_rend": normals_rend, - "normals_surf": normals_surf, - } - ) - - render_colors = torch.cat( - [ - render_colors[..., :3], - depths_expected, - ], - dim=-1, + render_colors[..., -1:] /= render_alphas.clamp(min=1e-10) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + render_colors[..., -1:], + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } ) - - meta.update( - { - "camera_ids": camera_ids, - "gaussian_ids": gaussian_ids, - "radii": radii, - "means2d": means2d, - "depths": depths, - "conics": conics, - "opacities": opacities, - "tile_width": tile_width, - "tile_height": tile_height, - "tiles_per_gauss": tiles_per_gauss, - "isect_ids": isect_ids, - "flatten_ids": flatten_ids, - "isect_offsets": isect_offsets, - "width": width, - "height": height, - "tile_size": tile_size, - } - ) return render_colors, render_alphas, meta @@ -557,28 +546,6 @@ def rasterization_inria_wrapper( GaussianRasterizer, ) - def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): - tanHalfFovY = math.tan((fovY / 2)) - tanHalfFovX = math.tan((fovX / 2)) - - top = tanHalfFovY * znear - bottom = -top - right = tanHalfFovX * znear - left = -right - - P = torch.zeros(4, 4, device=device) - - z_sign = 1.0 - - P[0, 0] = 2.0 * znear / (right - left) - P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) - P[3, 2] = z_sign - P[2, 2] = z_sign * zfar / (zfar - znear) - P[2, 3] = -(zfar * znear) / (zfar - znear) - return P - assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" C = len(viewmats) device = means.device @@ -685,28 +652,6 @@ def rasterization_2dgs_inria_wrapper( GaussianRasterizer, ) - def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): - tanHalfFovY = math.tan((fovY / 2)) - tanHalfFovX = math.tan((fovX / 2)) - - top = tanHalfFovY * znear - bottom = -top - right = tanHalfFovX * znear - left = -right - - P = torch.zeros(4, 4, device=device) - - z_sign = 1.0 - - P[0, 0] = 2.0 * znear / (right - left) - P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) - P[3, 2] = z_sign - P[2, 2] = z_sign * zfar / (zfar - znear) - P[2, 3] = -(zfar * znear) / (zfar - znear) - return P - assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" C = len(viewmats) device = means.device @@ -814,10 +759,10 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): ) normals_surf = normals_surf * (render_alphas).detach() + render_colors = torch.cat([render_colors, render_depth], dim=-1) meta = { "normals_rend": render_normal, "normals_surf": normals_surf, "render_distloss": render_dist, } - render_colors = torch.cat([render_colors, render_depth], dim=-1) return render_colors, render_alphas, meta diff --git a/gsplat/utils/__init__.py b/gsplat/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gsplat/utils/camera_utils.py b/gsplat/utils/camera_utils.py new file mode 100644 index 000000000..b57cd2e3d --- /dev/null +++ b/gsplat/utils/camera_utils.py @@ -0,0 +1,25 @@ +import math +import torch + + +def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): + tanHalfFovY = math.tan((fovY / 2)) + tanHalfFovX = math.tan((fovX / 2)) + + top = tanHalfFovY * znear + bottom = -top + right = tanHalfFovX * znear + left = -right + + P = torch.zeros(4, 4, device=device) + + z_sign = 1.0 + + P[0, 0] = 2.0 * znear / (right - left) + P[1, 1] = 2.0 * znear / (top - bottom) + P[0, 2] = (right + left) / (right - left) + P[1, 2] = (top + bottom) / (top - bottom) + P[3, 2] = z_sign + P[2, 2] = z_sign * zfar / (zfar - znear) + P[2, 3] = -(zfar * znear) / (zfar - znear) + return P diff --git a/gsplat/normal_utils.py b/gsplat/utils/normal_utils.py similarity index 74% rename from gsplat/normal_utils.py rename to gsplat/utils/normal_utils.py index 5124201f1..0615a3f8e 100644 --- a/gsplat/normal_utils.py +++ b/gsplat/utils/normal_utils.py @@ -1,9 +1,9 @@ import torch -import torch.nn as nn import torch.nn.functional as F -import numpy as np import math +from .camera_utils import _getProjectionMatrix + def _depths_to_points(depthmap, world_view_transform, full_proj_transform): c2w = (world_view_transform.T).inverse() @@ -38,7 +38,7 @@ def _depth_to_normal(depth, world_view_transform, full_proj_transform): output = torch.zeros_like(points) dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) - normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) + normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1) output[1:-1, 1:-1, :] = normal_map return output @@ -61,26 +61,3 @@ def depth_to_normal(depths, viewmats, Ks, near_plane, far_plane): normals.append(normal) normals = torch.stack(normals, dim=0) return normals - - -def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): - tanHalfFovY = math.tan((fovY / 2)) - tanHalfFovX = math.tan((fovX / 2)) - - top = tanHalfFovY * znear - bottom = -top - right = tanHalfFovX * znear - left = -right - - P = torch.zeros(4, 4, device=device) - - z_sign = 1.0 - - P[0, 0] = 2.0 * znear / (right - left) - P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) - P[3, 2] = z_sign - P[2, 2] = z_sign * zfar / (zfar - znear) - P[2, 3] = -(zfar * znear) / (zfar - znear) - return P From d7a291660d1080a32da80f4da47ea035a35392f2 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 13:56:18 -0700 Subject: [PATCH 35/66] rasterization backend --- examples/simple_trainer_mcmc.py | 14 ++++++++------ gsplat/rendering.py | 6 +++--- gsplat/utils/camera_utils.py | 2 +- gsplat/utils/normal_utils.py | 13 ++++++++++--- 4 files changed, 22 insertions(+), 13 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index b0ae95195..be19e3938 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -36,8 +36,8 @@ @dataclass class Config: - # Model type can be 3dgs, 3dgs_inria, or 2dgs_inria - model_type: str = "3dgs" + # Rasterization backend can be 3dgs, 3dgs_inria, or 2dgs_inria + rasterization_backend: str = "3dgs" # Disable viewer disable_viewer: bool = False # Path to the .pt file. If provide, it will skip training and render a video @@ -174,14 +174,16 @@ def __init__(self, cfg: Config) -> None: self.cfg = cfg self.device = "cuda" - if cfg.model_type == "3dgs": + if cfg.rasterization_backend == "3dgs": self.rasterization_fn = rasterization - elif cfg.model_type == "3dgs_inria": + elif cfg.rasterization_backend == "3dgs_inria": self.rasterization_fn = rasterization_inria_wrapper - elif cfg.model_type == "2dgs_inria": + elif cfg.rasterization_backend == "2dgs_inria": self.rasterization_fn = rasterization_2dgs_inria_wrapper else: - raise ValueError(f"Unsupported model type: {cfg.model_type}") + raise ValueError( + f"Unsupported rasterization backend: {cfg.rasterization_backend}" + ) self.render_mode = "RGB" if cfg.depth_loss: diff --git a/gsplat/rendering.py b/gsplat/rendering.py index bb989786f..cd6ff30a5 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -14,7 +14,7 @@ spherical_harmonics, ) from .utils.normal_utils import depth_to_normal -from .utils.camera_utils import _getProjectionMatrix +from .utils.camera_utils import getProjectionMatrix def rasterization( @@ -562,7 +562,7 @@ def rasterization_inria_wrapper( tanfovy = math.tan(FoVy * 0.5) world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = _getProjectionMatrix( + projection_matrix = getProjectionMatrix( znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device ).transpose(0, 1) full_proj_transform = ( @@ -669,7 +669,7 @@ def rasterization_2dgs_inria_wrapper( tanfovy = math.tan(FoVy * 0.5) world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = _getProjectionMatrix( + projection_matrix = getProjectionMatrix( znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device ).transpose(0, 1) full_proj_transform = ( diff --git a/gsplat/utils/camera_utils.py b/gsplat/utils/camera_utils.py index b57cd2e3d..924e87115 100644 --- a/gsplat/utils/camera_utils.py +++ b/gsplat/utils/camera_utils.py @@ -2,7 +2,7 @@ import torch -def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): +def getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) diff --git a/gsplat/utils/normal_utils.py b/gsplat/utils/normal_utils.py index 0615a3f8e..9ca55c8c9 100644 --- a/gsplat/utils/normal_utils.py +++ b/gsplat/utils/normal_utils.py @@ -1,8 +1,9 @@ import torch import torch.nn.functional as F import math +from torch import Tensor -from .camera_utils import _getProjectionMatrix +from .camera_utils import getProjectionMatrix def _depths_to_points(depthmap, world_view_transform, full_proj_transform): @@ -43,7 +44,13 @@ def _depth_to_normal(depth, world_view_transform, full_proj_transform): return output -def depth_to_normal(depths, viewmats, Ks, near_plane, far_plane): +def depth_to_normal( + depths: Tensor, # [C, H, W, 1] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + near_plane: float = 0.01, + far_plane: float = 1e10, +) -> Tensor: height, width = depths.shape[1:3] normals = [] @@ -51,7 +58,7 @@ def depth_to_normal(depths, viewmats, Ks, near_plane, far_plane): FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = _getProjectionMatrix( + projection_matrix = getProjectionMatrix( znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device ).transpose(0, 1) full_proj_transform = ( From 0274d1b1e9e1372acd4f5791a6cc5032c1f8466e Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 14:40:47 -0700 Subject: [PATCH 36/66] v_rotmat --- gsplat/cuda/csrc/fully_fused_projection_bwd.cu | 6 +++--- .../csrc/fully_fused_projection_packed_bwd.cu | 17 +++++++++-------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index a437e8918..73345cbde 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -156,9 +156,9 @@ __global__ void fully_fused_projection_bwd_kernel( quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); // add contribution from v_normals. Please check if this is correct. - mat3 v_R = quat_to_rotmat(quat); - v_R[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_R, v_quat); + mat3 v_rotmat = quat_to_rotmat(quat); + v_rotmat[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_rotmat, v_quat); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index 661228c39..fc9994900 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -147,6 +147,12 @@ __global__ void fully_fused_projection_packed_bwd_kernel( vec4 v_quat(0.f); vec3 v_scale(0.f); quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals. Please check if this is correct. + mat3 v_rotmat = quat_to_rotmat(quat); + v_rotmat[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + v_quats += idx * 4; v_scales += idx * 3; v_quats[0] = v_quat[0]; @@ -156,11 +162,6 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_scales[0] = v_scale[0]; v_scales[1] = v_scale[1]; v_scales[2] = v_scale[2]; - - // add contribution from v_normals. Please check if this is correct. - mat3 v_R = quat_to_rotmat(quat); - v_R[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_R, v_quat); } } else { // write out results with dense layout @@ -197,9 +198,9 @@ __global__ void fully_fused_projection_packed_bwd_kernel( quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); // add contribution from v_normals. Please check if this is correct. - mat3 v_R = quat_to_rotmat(quat); - v_R[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_R, v_quat); + mat3 v_rotmat = quat_to_rotmat(quat); + v_rotmat[2] += glm::make_vec3(v_normals); + quat_to_rotmat_vjp(quat, v_rotmat, v_quat); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); From 6840727ac97a969aa61d0aae05bac9918f94ab21 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 14:50:03 -0700 Subject: [PATCH 37/66] render traj --- examples/simple_trainer_mcmc.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index be19e3938..ad9041bc8 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -14,7 +14,7 @@ import viser import nerfview from datasets.colmap import Dataset, Parser -from datasets.traj import generate_interpolated_path +from datasets.traj import generate_interpolated_path, generate_ellipse_path_z from torch import Tensor from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure @@ -42,6 +42,8 @@ class Config: disable_viewer: bool = False # Path to the .pt file. If provide, it will skip training and render a video ckpt: Optional[str] = None + # Render trajectory path + render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset data_dir: str = "data/360_v2/garden" @@ -784,7 +786,20 @@ def render_traj(self, step: int): device = self.device camtoworlds_all = self.parser.camtoworlds[5:-5] - camtoworlds_all = generate_interpolated_path(camtoworlds_all, 1) # [N, 3, 4] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + camtoworlds_all = np.concatenate( [ camtoworlds_all, From 5fdc9345e03c59473e2c61d41fed4e926d13115b Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 14:56:11 -0700 Subject: [PATCH 38/66] point inward during ellipse --- examples/datasets/traj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/datasets/traj.py b/examples/datasets/traj.py index 8d49aa711..8fcc981b2 100644 --- a/examples/datasets/traj.py +++ b/examples/datasets/traj.py @@ -90,7 +90,7 @@ def get_positions(theta): ind_up = np.argmax(np.abs(avg_up)) up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) - return np.stack([viewmatrix(p - center, up, p) for p in positions]) + return np.stack([viewmatrix(center - p, up, p) for p in positions]) def generate_ellipse_path_y( From e3dc5e509e926236786f8a9f60ae35651d248b7f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 21:30:24 -0700 Subject: [PATCH 39/66] add tests for fwd and bwd pass --- gsplat/cuda/_torch_impl.py | 32 ++++++++++----- .../cuda/csrc/fully_fused_projection_bwd.cu | 8 +--- gsplat/cuda/csrc/utils.cuh | 41 +++++++++++++++++++ tests/test_basic.py | 29 +++++++++---- 4 files changed, 84 insertions(+), 26 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index e73cdd86a..58b52aca1 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -6,14 +6,7 @@ from torch import Tensor -def _quat_scale_to_covar_preci( - quats: Tensor, # [N, 4], - scales: Tensor, # [N, 3], - compute_covar: bool = True, - compute_preci: bool = True, - triu: bool = False, -) -> Tuple[Optional[Tensor], Optional[Tensor]]: - """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`.""" +def _quat_to_rotmat(quats: Tensor) -> Tensor: quats = F.normalize(quats, p=2, dim=-1) w, x, y, z = torch.unbind(quats, dim=-1) R = torch.stack( @@ -30,8 +23,19 @@ def _quat_scale_to_covar_preci( ], dim=-1, ) - R = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3) + return R + + +def _quat_scale_to_covar_preci( + quats: Tensor, # [N, 4], + scales: Tensor, # [N, 3], + compute_covar: bool = True, + compute_preci: bool = True, + triu: bool = False, +) -> Tuple[Optional[Tensor], Optional[Tensor]]: + """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`.""" + R = _quat_to_rotmat(quats) # R.register_hook(lambda grad: print("grad R", grad)) if compute_covar: @@ -129,7 +133,8 @@ def _world_to_cam( def _fully_fused_projection( means: Tensor, # [N, 3] - covars: Tensor, # [N, 3, 3] + quats: Tensor, + scales: Tensor, viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, @@ -146,6 +151,11 @@ def _fully_fused_projection( This is a minimal implementation of fully fused version, which has more arguments. Not all arguments are supported. """ + + covars, _ = _quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] + normals = _quat_to_rotmat(quats)[..., 2] + normals = normals.repeat(viewmats.shape[0], 1, 1) + means_c, covars_c = _world_to_cam(means, covars, viewmats) means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height) det_orig = ( @@ -194,7 +204,7 @@ def _fully_fused_projection( radius[~inside] = 0.0 radii = radius.int() - return radii, means2d, depths, conics, compensations + return radii, means2d, depths, normals, conics, compensations @torch.no_grad() diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 73345cbde..928367106 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -153,12 +153,8 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - - // add contribution from v_normals. Please check if this is correct. - mat3 v_rotmat = quat_to_rotmat(quat); - v_rotmat[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + // quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + quat_scale_to_covar_vjp_normal(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 57751315f..bd75caad8 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -106,6 +106,47 @@ inline __device__ void quat_scale_to_covar_vjp( v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; } +template +inline __device__ void quat_scale_to_covar_vjp_normal( + // fwd inputs + const vec4 quat, const vec3 scale, + // precompute + const mat3 R, + // grad outputs + const mat3 v_covar, + const vec3 v_normal, + // grad inputs + vec4 &v_quat, vec3 &v_scale) { + T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; + T sx = scale[0], sy = scale[1], sz = scale[2]; + + // M = R * S + mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); + mat3 M = R * S; + + // https://math.stackexchange.com/a/3850121 + // for D = W * X, G = df/dD + // df/dW = G * XT, df/dX = WT * G + // so + // for D = M * Mt, + // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M + mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; + mat3 v_R = v_M * S; + + v_R[2] += v_normal; + + // // add contribution from v_normals. Please check if this is correct. + // mat3 v_rotmat = quat_to_rotmat(quat); + // quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + + // grad for (quat, scale) from covar + quat_to_rotmat_vjp(quat, v_R, v_quat); + + v_scale[0] += R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; + v_scale[1] += R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; + v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; +} + template inline __device__ void quat_scale_to_preci_vjp( // fwd inputs diff --git a/tests/test_basic.py b/tests/test_basic.py index 8c546b450..a4b367f85 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -159,10 +159,10 @@ def test_persp_proj(test_data): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -@pytest.mark.parametrize("fused", [False, True]) -@pytest.mark.parametrize("calc_compensations", [False, True]) +@pytest.mark.parametrize("fused", [True]) +@pytest.mark.parametrize("calc_compensations", [True]) def test_projection(test_data, fused: bool, calc_compensations: bool): - from gsplat.cuda._torch_impl import _fully_fused_projection + from gsplat.cuda._torch_impl import _fully_fused_projection, _quat_to_rotmat from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci torch.manual_seed(42) @@ -181,7 +181,7 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): # forward if fused: - radii, means2d, depths, conics, compensations = fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( means, None, quats, @@ -194,7 +194,7 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): ) else: covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] - radii, means2d, depths, conics, compensations = fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( means, covars, None, @@ -205,10 +205,17 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): height, calc_compensations=calc_compensations, ) - _covars, _ = quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] - _radii, _means2d, _depths, _conics, _compensations = _fully_fused_projection( + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = _fully_fused_projection( means, - _covars, + quats, + scales, viewmats, Ks, width, @@ -221,6 +228,7 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): torch.testing.assert_close(radii, _radii, rtol=0, atol=1) torch.testing.assert_close(means2d[valid], _means2d[valid], rtol=1e-4, atol=1e-4) torch.testing.assert_close(depths[valid], _depths[valid], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(normals[valid], _normals[valid], rtol=1e-4, atol=1e-4) torch.testing.assert_close(conics[valid], _conics[valid], rtol=1e-4, atol=1e-4) if calc_compensations: torch.testing.assert_close( @@ -230,12 +238,14 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): # backward v_means2d = torch.randn_like(means2d) * radii[..., None] v_depths = torch.randn_like(depths) * radii + v_normals = torch.randn_like(normals) * radii[..., None] v_conics = torch.randn_like(conics) * radii[..., None] if calc_compensations: v_compensations = torch.randn_like(compensations) * radii v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d).sum() + (depths * v_depths).sum() + + (normals * v_normals).sum() + (conics * v_conics).sum() + ((compensations * v_compensations).sum() if calc_compensations else 0), (viewmats, quats, scales, means), @@ -243,12 +253,13 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): _v_viewmats, _v_quats, _v_scales, _v_means = torch.autograd.grad( (_means2d * v_means2d).sum() + (_depths * v_depths).sum() + + (_normals * v_normals).sum() + (_conics * v_conics).sum() + ((_compensations * v_compensations).sum() if calc_compensations else 0), (viewmats, quats, scales, means), ) - torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-2, atol=1e-2) torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-2) torch.testing.assert_close(v_scales, _v_scales, rtol=1e-1, atol=2e-1) torch.testing.assert_close(v_means, _v_means, rtol=1e-2, atol=6e-2) From 467ffe685c0dbbeac433711b72b89b213540d1d6 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Mon, 15 Jul 2024 22:40:57 -0700 Subject: [PATCH 40/66] all but one test passing --- gsplat/cuda/_torch_impl.py | 4 +- .../cuda/csrc/fully_fused_projection_bwd.cu | 3 +- .../csrc/fully_fused_projection_packed_bwd.cu | 14 +- gsplat/cuda/csrc/utils.cuh | 7 +- tests/test_basic.py | 177 +++++++----------- 5 files changed, 76 insertions(+), 129 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 58b52aca1..f6572bb93 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -151,10 +151,8 @@ def _fully_fused_projection( This is a minimal implementation of fully fused version, which has more arguments. Not all arguments are supported. """ - + normals = _quat_to_rotmat(quats)[..., 2].repeat(viewmats.shape[0], 1, 1) covars, _ = _quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] - normals = _quat_to_rotmat(quats)[..., 2] - normals = normals.repeat(viewmats.shape[0], 1, 1) means_c, covars_c = _world_to_cam(means, covars, viewmats) means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 928367106..5246c2f38 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -153,8 +153,7 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - // quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - quat_scale_to_covar_vjp_normal(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); + quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index fc9994900..b1531083e 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -146,12 +146,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - - // add contribution from v_normals. Please check if this is correct. - mat3 v_rotmat = quat_to_rotmat(quat); - v_rotmat[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); v_quats += idx * 4; v_scales += idx * 3; @@ -195,12 +190,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); - - // add contribution from v_normals. Please check if this is correct. - mat3 v_rotmat = quat_to_rotmat(quat); - v_rotmat[2] += glm::make_vec3(v_normals); - quat_to_rotmat_vjp(quat, v_rotmat, v_quat); + quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index bd75caad8..47416dcf7 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -107,7 +107,7 @@ inline __device__ void quat_scale_to_covar_vjp( } template -inline __device__ void quat_scale_to_covar_vjp_normal( +inline __device__ void quat_scale_to_covar_normal_vjp( // fwd inputs const vec4 quat, const vec3 scale, // precompute @@ -133,12 +133,9 @@ inline __device__ void quat_scale_to_covar_vjp_normal( mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; mat3 v_R = v_M * S; + // add contribution from v_normal v_R[2] += v_normal; - // // add contribution from v_normals. Please check if this is correct. - // mat3 v_rotmat = quat_to_rotmat(quat); - // quat_to_rotmat_vjp(quat, v_rotmat, v_quat); - // grad for (quat, scale) from covar quat_to_rotmat_vjp(quat, v_R, v_quat); diff --git a/tests/test_basic.py b/tests/test_basic.py index a4b367f85..5a11c97e4 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -159,9 +159,8 @@ def test_persp_proj(test_data): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -@pytest.mark.parametrize("fused", [True]) -@pytest.mark.parametrize("calc_compensations", [True]) -def test_projection(test_data, fused: bool, calc_compensations: bool): +@pytest.mark.parametrize("calc_compensations", [False, True]) +def test_projection(test_data, calc_compensations: bool): from gsplat.cuda._torch_impl import _fully_fused_projection, _quat_to_rotmat from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci @@ -180,31 +179,17 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): means.requires_grad = True # forward - if fused: - radii, means2d, depths, normals, conics, compensations = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - calc_compensations=calc_compensations, - ) - else: - covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] - radii, means2d, depths, normals, conics, compensations = fully_fused_projection( - means, - covars, - None, - None, - viewmats, - Ks, - width, - height, - calc_compensations=calc_compensations, - ) + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + calc_compensations=calc_compensations, + ) ( _radii, _means2d, @@ -260,19 +245,18 @@ def test_projection(test_data, fused: bool, calc_compensations: bool): ) torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-2) - torch.testing.assert_close(v_scales, _v_scales, rtol=1e-1, atol=2e-1) + torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-1) + torch.testing.assert_close(v_scales, _v_scales, rtol=2e-1, atol=2e-1) torch.testing.assert_close(v_means, _v_means, rtol=1e-2, atol=6e-2) @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -@pytest.mark.parametrize("fused", [False, True]) -@pytest.mark.parametrize("sparse_grad", [False, True]) -@pytest.mark.parametrize("calc_compensations", [False, True]) +@pytest.mark.parametrize("sparse_grad", [True]) +@pytest.mark.parametrize("calc_compensations", [False]) def test_fully_fused_projection_packed( - test_data, fused: bool, sparse_grad: bool, calc_compensations: bool + test_data, sparse_grad: bool, calc_compensations: bool ): - from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci + from gsplat.cuda._wrapper import fully_fused_projection torch.manual_seed(42) @@ -289,75 +273,47 @@ def test_fully_fused_projection_packed( means.requires_grad = True # forward - if fused: - ( - camera_ids, - gaussian_ids, - radii, - means2d, - depths, - conics, - compensations, - ) = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - packed=True, - sparse_grad=sparse_grad, - calc_compensations=calc_compensations, - ) - _radii, _means2d, _depths, _conics, _compensations = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - packed=False, - calc_compensations=calc_compensations, - ) - else: - covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] - ( - camera_ids, - gaussian_ids, - radii, - means2d, - depths, - conics, - compensations, - ) = fully_fused_projection( - means, - covars, - None, - None, - viewmats, - Ks, - width, - height, - packed=True, - sparse_grad=sparse_grad, - calc_compensations=calc_compensations, - ) - _radii, _means2d, _depths, _conics, _compensations = fully_fused_projection( - means, - covars, - None, - None, - viewmats, - Ks, - width, - height, - packed=False, - calc_compensations=calc_compensations, - ) + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=True, + sparse_grad=sparse_grad, + calc_compensations=calc_compensations, + ) + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=False, + calc_compensations=calc_compensations, + ) # recover packed tensors to full matrices for testing __radii = torch.sparse_coo_tensor( @@ -369,6 +325,9 @@ def test_fully_fused_projection_packed( __depths = torch.sparse_coo_tensor( torch.stack([camera_ids, gaussian_ids]), depths, _depths.shape ).to_dense() + __normals = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), normals, _normals.shape + ).to_dense() __conics = torch.sparse_coo_tensor( torch.stack([camera_ids, gaussian_ids]), conics, _conics.shape ).to_dense() @@ -380,6 +339,7 @@ def test_fully_fused_projection_packed( torch.testing.assert_close(__radii[sel], _radii[sel], rtol=0, atol=1) torch.testing.assert_close(__means2d[sel], _means2d[sel], rtol=1e-4, atol=1e-4) torch.testing.assert_close(__depths[sel], _depths[sel], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(__normals[sel], _normals[sel], rtol=1e-4, atol=1e-4) torch.testing.assert_close(__conics[sel], _conics[sel], rtol=1e-4, atol=1e-4) if calc_compensations: torch.testing.assert_close( @@ -389,10 +349,12 @@ def test_fully_fused_projection_packed( # backward v_means2d = torch.randn_like(_means2d) * sel[..., None] v_depths = torch.randn_like(_depths) * sel + v_normals = torch.randn_like(_normals) * sel[..., None] v_conics = torch.randn_like(_conics) * sel[..., None] _v_viewmats, _v_quats, _v_scales, _v_means = torch.autograd.grad( (_means2d * v_means2d).sum() + (_depths * v_depths).sum() + + (_normals * v_normals).sum() + (_conics * v_conics).sum(), (viewmats, quats, scales, means), retain_graph=True, @@ -400,6 +362,7 @@ def test_fully_fused_projection_packed( v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d[sel]).sum() + (depths * v_depths[sel]).sum() + + (normals * v_normals[sel]).sum() + (conics * v_conics[sel]).sum(), (viewmats, quats, scales, means), retain_graph=True, @@ -475,11 +438,11 @@ def test_rasterize_to_pixels(test_data, channels: int): colors = torch.randn(C, len(means), channels, device=device) backgrounds = torch.rand((C, colors.shape[-1]), device=device) - covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) + # covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) # Project Gaussians to 2D - radii, means2d, depths, conics, compensations = fully_fused_projection( - means, covars, None, None, viewmats, Ks, width, height + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( + means, None, quats, scales, viewmats, Ks, width, height ) opacities = opacities.repeat(C, 1) From 8e158a2bff1d2759be4be77e5eeed5e95ea1fbf9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 00:40:36 -0700 Subject: [PATCH 41/66] weird bug --- gsplat/cuda/csrc/fully_fused_projection_fwd.cu | 2 +- gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu | 2 +- tests/test_basic.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index 8a44269f7..9d13deaa1 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -78,7 +78,7 @@ fully_fused_projection_fwd_kernel(const uint32_t C, const uint32_t N, scales += gid * 3; quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); - glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); normal = rotmat[2]; } mat3 covar_c; diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 719895e38..04fd823ba 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -95,7 +95,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( scales += col_idx * 3; quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); - glm::mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); normal = rotmat[2]; } mat3 covar_c; diff --git a/tests/test_basic.py b/tests/test_basic.py index 5a11c97e4..8c1b5b15e 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -251,8 +251,8 @@ def test_projection(test_data, calc_compensations: bool): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -@pytest.mark.parametrize("sparse_grad", [True]) -@pytest.mark.parametrize("calc_compensations", [False]) +@pytest.mark.parametrize("sparse_grad", [False, True]) +@pytest.mark.parametrize("calc_compensations", [False, True]) def test_fully_fused_projection_packed( test_data, sparse_grad: bool, calc_compensations: bool ): @@ -260,8 +260,8 @@ def test_fully_fused_projection_packed( torch.manual_seed(42) - Ks = test_data["Ks"] - viewmats = test_data["viewmats"] + Ks = test_data["Ks"][:2] + viewmats = test_data["viewmats"][:2] height = test_data["height"] width = test_data["width"] quats = test_data["quats"] From 08b4631c978cfe7278ead51a3eb9487789331c9e Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 13:28:46 -0700 Subject: [PATCH 42/66] test passes but test suite does not pass --- gsplat/cuda/csrc/fully_fused_projection_bwd.cu | 3 ++- gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu | 5 ++++- tests/test_basic.py | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 5246c2f38..2ef12278d 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -60,7 +60,6 @@ __global__ void fully_fused_projection_bwd_kernel( v_means2d += idx * 2; v_depths += idx; - v_normals += idx * 3; v_conics += idx * 3; // vjp: compute the inverse of the 2d covariance @@ -153,6 +152,8 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + + v_normals += idx * 3; quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index b1531083e..734f41e9a 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -62,7 +62,6 @@ __global__ void fully_fused_projection_packed_bwd_kernel( v_means2d += idx * 2; v_depths += idx; - v_normals += idx * 3; v_conics += idx * 3; // vjp: compute the inverse of the 2d covariance @@ -146,6 +145,8 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + + v_normals += idx * 3; quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); v_quats += idx * 4; @@ -190,6 +191,8 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + + v_normals += idx * 3; quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); warpSum(v_quat, warp_group_g); diff --git a/tests/test_basic.py b/tests/test_basic.py index 8c1b5b15e..1cf1b0a61 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -260,8 +260,8 @@ def test_fully_fused_projection_packed( torch.manual_seed(42) - Ks = test_data["Ks"][:2] - viewmats = test_data["viewmats"][:2] + Ks = test_data["Ks"] + viewmats = test_data["viewmats"] height = test_data["height"] width = test_data["width"] quats = test_data["quats"] From c7e2fd4b0556f3adf5167723732d5e0bd17540d9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 14:55:49 -0700 Subject: [PATCH 43/66] count radii --- tests/test_basic.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_basic.py b/tests/test_basic.py index 1cf1b0a61..844876f75 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -359,6 +359,9 @@ def test_fully_fused_projection_packed( (viewmats, quats, scales, means), retain_graph=True, ) + + sel = (__radii > 0) + print(torch.count_nonzero(_radii), torch.count_nonzero(__radii)) v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d[sel]).sum() + (depths * v_depths[sel]).sum() From ee40bb21f295b7367f68b4cd06c8f8e7d4fcf643 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 15:25:39 -0700 Subject: [PATCH 44/66] change sel to pass tests --- tests/test_basic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 844876f75..0fa884c44 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -361,7 +361,6 @@ def test_fully_fused_projection_packed( ) sel = (__radii > 0) - print(torch.count_nonzero(_radii), torch.count_nonzero(__radii)) v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d[sel]).sum() + (depths * v_depths[sel]).sum() From 6f21cb7f38566abc18001271fb0354607c75b176 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 15:30:12 -0700 Subject: [PATCH 45/66] __sel --- tests/test_basic.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 0fa884c44..a68f6d60f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -335,7 +335,9 @@ def test_fully_fused_projection_packed( __compensations = torch.sparse_coo_tensor( torch.stack([camera_ids, gaussian_ids]), compensations, _compensations.shape ).to_dense() - sel = (__radii > 0) & (_radii > 0) + __sel = __radii > 0 + _sel = _radii > 0 + sel = (__sel) & (_sel) torch.testing.assert_close(__radii[sel], _radii[sel], rtol=0, atol=1) torch.testing.assert_close(__means2d[sel], _means2d[sel], rtol=1e-4, atol=1e-4) torch.testing.assert_close(__depths[sel], _depths[sel], rtol=1e-4, atol=1e-4) @@ -359,13 +361,11 @@ def test_fully_fused_projection_packed( (viewmats, quats, scales, means), retain_graph=True, ) - - sel = (__radii > 0) v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( - (means2d * v_means2d[sel]).sum() - + (depths * v_depths[sel]).sum() - + (normals * v_normals[sel]).sum() - + (conics * v_conics[sel]).sum(), + (means2d * v_means2d[__sel]).sum() + + (depths * v_depths[__sel]).sum() + + (normals * v_normals[__sel]).sum() + + (conics * v_conics[__sel]).sum(), (viewmats, quats, scales, means), retain_graph=True, ) From d2f89abc0e0acb35924a8ab30730d8b218c7d88d Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 15:37:48 -0700 Subject: [PATCH 46/66] cleanup --- gsplat/cuda/_torch_impl.py | 3 ++- tests/test_basic.py | 10 +++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index f6572bb93..d70a0ea52 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -151,8 +151,9 @@ def _fully_fused_projection( This is a minimal implementation of fully fused version, which has more arguments. Not all arguments are supported. """ - normals = _quat_to_rotmat(quats)[..., 2].repeat(viewmats.shape[0], 1, 1) covars, _ = _quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] + normals = _quat_to_rotmat(quats)[..., 2] # [N, 3] + normals = normals.repeat(viewmats.shape[0], 1, 1) # [C, N, 3] means_c, covars_c = _world_to_cam(means, covars, viewmats) means2d, covars2d = _persp_proj(means_c, covars_c, Ks, width, height) diff --git a/tests/test_basic.py b/tests/test_basic.py index a68f6d60f..8d2eda4fa 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -244,9 +244,9 @@ def test_projection(test_data, calc_compensations: bool): (viewmats, quats, scales, means), ) - torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-1) - torch.testing.assert_close(v_scales, _v_scales, rtol=2e-1, atol=2e-1) + torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-2) + torch.testing.assert_close(v_scales, _v_scales, rtol=1e-1, atol=2e-1) torch.testing.assert_close(v_means, _v_means, rtol=1e-2, atol=6e-2) @@ -421,8 +421,6 @@ def test_rasterize_to_pixels(test_data, channels: int): fully_fused_projection, isect_offset_encode, isect_tiles, - persp_proj, - quat_scale_to_covar_preci, rasterize_to_pixels, ) @@ -440,8 +438,6 @@ def test_rasterize_to_pixels(test_data, channels: int): colors = torch.randn(C, len(means), channels, device=device) backgrounds = torch.rand((C, colors.shape[-1]), device=device) - # covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) - # Project Gaussians to 2D radii, means2d, depths, normals, conics, compensations = fully_fused_projection( means, None, quats, scales, viewmats, Ks, width, height From 12723b19e31b1c2e1550082520ea588413143b55 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 16 Jul 2024 17:15:49 -0700 Subject: [PATCH 47/66] simplify utils --- .../cuda/csrc/fully_fused_projection_bwd.cu | 8 +++- .../csrc/fully_fused_projection_packed_bwd.cu | 16 +++++++- gsplat/cuda/csrc/utils.cuh | 38 ------------------- 3 files changed, 21 insertions(+), 41 deletions(-) diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index 2ef12278d..d75b56d12 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -152,9 +152,15 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + // add contribution from v_normals v_normals += idx * 3; - quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat + ); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index 734f41e9a..10e6a9256 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -145,9 +145,15 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + // add contribution from v_normals v_normals += idx * 3; - quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat + ); v_quats += idx * 4; v_scales += idx * 3; @@ -191,9 +197,15 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + // add contribution from v_normals v_normals += idx * 3; - quat_scale_to_covar_normal_vjp(quat, scale, rotmat, v_covar, glm::make_vec3(v_normals), v_quat, v_scale); + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat + ); warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index 6dd3784c3..7e23a0f8e 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -105,44 +105,6 @@ inline __device__ void quat_scale_to_covar_vjp( v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; } -template -inline __device__ void quat_scale_to_covar_normal_vjp( - // fwd inputs - const vec4 quat, const vec3 scale, - // precompute - const mat3 R, - // grad outputs - const mat3 v_covar, - const vec3 v_normal, - // grad inputs - vec4 &v_quat, vec3 &v_scale) { - T w = quat[0], x = quat[1], y = quat[2], z = quat[3]; - T sx = scale[0], sy = scale[1], sz = scale[2]; - - // M = R * S - mat3 S = mat3(sx, 0.f, 0.f, 0.f, sy, 0.f, 0.f, 0.f, sz); - mat3 M = R * S; - - // https://math.stackexchange.com/a/3850121 - // for D = W * X, G = df/dD - // df/dW = G * XT, df/dX = WT * G - // so - // for D = M * Mt, - // df/dM = df/dM + df/dMt = G * M + (Mt * G)t = G * M + Gt * M - mat3 v_M = (v_covar + glm::transpose(v_covar)) * M; - mat3 v_R = v_M * S; - - // add contribution from v_normal - v_R[2] += v_normal; - - // grad for (quat, scale) from covar - quat_to_rotmat_vjp(quat, v_R, v_quat); - - v_scale[0] += R[0][0] * v_M[0][0] + R[0][1] * v_M[0][1] + R[0][2] * v_M[0][2]; - v_scale[1] += R[1][0] * v_M[1][0] + R[1][1] * v_M[1][1] + R[1][2] * v_M[1][2]; - v_scale[2] += R[2][0] * v_M[2][0] + R[2][1] * v_M[2][1] + R[2][2] * v_M[2][2]; -} - template inline __device__ void quat_scale_to_preci_vjp( // fwd inputs From fadb7b409d79173d224c2f683d5ddb423ce90d37 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 19 Jul 2024 12:06:15 -0700 Subject: [PATCH 48/66] util --- gsplat/rendering.py | 4 ++-- gsplat/{utils => util}/camera_utils.py | 0 gsplat/{utils => util}/normal_utils.py | 0 gsplat/utils/__init__.py | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename gsplat/{utils => util}/camera_utils.py (100%) rename gsplat/{utils => util}/normal_utils.py (100%) delete mode 100644 gsplat/utils/__init__.py diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 49937472e..3dd52e13f 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -13,8 +13,8 @@ rasterize_to_pixels, spherical_harmonics, ) -from .utils.normal_utils import depth_to_normal -from .utils.camera_utils import getProjectionMatrix +from .util.normal_utils import depth_to_normal +from .util.camera_utils import getProjectionMatrix def rasterization( diff --git a/gsplat/utils/camera_utils.py b/gsplat/util/camera_utils.py similarity index 100% rename from gsplat/utils/camera_utils.py rename to gsplat/util/camera_utils.py diff --git a/gsplat/utils/normal_utils.py b/gsplat/util/normal_utils.py similarity index 100% rename from gsplat/utils/normal_utils.py rename to gsplat/util/normal_utils.py diff --git a/gsplat/utils/__init__.py b/gsplat/utils/__init__.py deleted file mode 100644 index e69de29bb..000000000 From f40b87910e30e905e87fbfdd15560344d1d29cc9 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 19 Jul 2024 12:29:42 -0700 Subject: [PATCH 49/66] test rasterization --- gsplat/rendering.py | 35 ++++++++++++++++++++++++++++++----- tests/test_rasterization.py | 2 +- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 3dd52e13f..48119c906 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -401,6 +401,25 @@ def rasterization( "tile_size": tile_size, "n_cameras": C, } + if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: + # normalize the accumulated depth to get the expected depth + render_colors[..., -1:] /= render_alphas.clamp(min=1e-10) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + render_colors[..., -1:], + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } + ) return render_colors, render_alphas, meta @@ -420,7 +439,7 @@ def _rasterization( sh_degree: Optional[int] = None, tile_size: int = 16, backgrounds: Optional[Tensor] = None, - render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"] = "RGB", rasterize_mode: Literal["classic", "antialiased"] = "classic", channel_chunk: int = 32, batch_per_iter: int = 100, @@ -454,7 +473,7 @@ def _rasterization( assert opacities.shape == (N,), opacities.shape assert viewmats.shape == (C, 4, 4), viewmats.shape assert Ks.shape == (C, 3, 3), Ks.shape - assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"], render_mode if sh_degree is None: # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] @@ -473,10 +492,10 @@ def _rasterization( # Project Gaussians to 2D. # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. - covars, _ = _quat_scale_to_covar_preci(quats, scales, True, False, triu=False) - radii, means2d, depths, conics, compensations = _fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = _fully_fused_projection( means, - covars, + quats, + scales, viewmats, Ks, width, @@ -544,6 +563,12 @@ def _rasterization( colors = depths[..., None] if backgrounds is not None: backgrounds = torch.zeros(C, 1, device=backgrounds.device) + elif render_mode in ["RGB+ED+N"]: + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + if backgrounds is not None: + backgrounds = torch.cat( + [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + ) else: # RGB pass if colors.shape[-1] > channel_chunk: diff --git a/tests/test_rasterization.py b/tests/test_rasterization.py index 57ea44bee..ce436352d 100644 --- a/tests/test_rasterization.py +++ b/tests/test_rasterization.py @@ -18,7 +18,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") @pytest.mark.parametrize("per_view_color", [True, False]) @pytest.mark.parametrize("sh_degree", [None, 3]) -@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D"]) +@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D", "RGB+ED+N"]) @pytest.mark.parametrize("packed", [True, False]) def test_rasterization( per_view_color: bool, sh_degree: Optional[int], render_mode: str, packed: bool From d766117b8575893a3897b4c6cf7f15026e8de73c Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 19 Jul 2024 12:43:32 -0700 Subject: [PATCH 50/66] benchmark script --- examples/benchmark_mcmc.sh | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 examples/benchmark_mcmc.sh diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh new file mode 100644 index 000000000..a960e1abc --- /dev/null +++ b/examples/benchmark_mcmc.sh @@ -0,0 +1,33 @@ + + +SCENE_DIR="data/360_v2" +RESULTS_DIR="results/360_v2" +SCENE_LIST="garden bicycle stump treehill flowers bonsai counter kitchen room" +RENDER_TRAJ_PATH="ellipse" + + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + CAP_MAX=3000000 + MAX_STEPS=30000 + EVAL_STEPS="1000 7000 15000 30000" + SAVE_STEPS="15000 30000" + + python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --save_steps $SAVE_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + --init_type sfm \ + --cap_max $CAP_MAX \ + --max_steps $MAX_STEPS \ + --data_dir $SCENE_DIR/$SCENE/ \ + --render_traj_path $RENDER_TRAJ_PATH \ + --normal_consistency_loss \ + --result_dir $RESULTS_DIR/3dgs_normal/$SCENE/ + +done From 8e8d6816b13632e5d48fab0446a15781bc8ed840 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 19 Jul 2024 12:45:59 -0700 Subject: [PATCH 51/66] __init__ --- gsplat/util/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 gsplat/util/__init__.py diff --git a/gsplat/util/__init__.py b/gsplat/util/__init__.py new file mode 100644 index 000000000..e69de29bb From 24eb1b5fd5ca31c36a77ecc74352e8da305302b1 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 30 Aug 2024 21:42:58 -0700 Subject: [PATCH 52/66] fix normal consistency --- examples/benchmarks/mcmc.sh | 21 +++--- examples/simple_trainer.py | 72 ++++++++++++++++--- gsplat/cuda/csrc/bindings.h | 2 + .../cuda/csrc/fully_fused_projection_bwd.cu | 2 + .../cuda/csrc/fully_fused_projection_fwd.cu | 3 + .../csrc/fully_fused_projection_packed_bwd.cu | 2 + .../csrc/fully_fused_projection_packed_fwd.cu | 5 ++ gsplat/rendering.py | 19 +++++ 8 files changed, 106 insertions(+), 20 deletions(-) diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh index 4b0add753..ae0735b4c 100644 --- a/examples/benchmarks/mcmc.sh +++ b/examples/benchmarks/mcmc.sh @@ -3,6 +3,8 @@ RESULT_DIR="results/benchmark_mcmc_1M" SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers CAP_MAX=1000000 +EVAL_STEPS="7000 30000" +SAVE_STEPS="7000 30000" for SCENE in $SCENE_LIST; do @@ -15,20 +17,21 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps $EVAL_STEPS --save_steps $SAVE_STEPS --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ + --normal_consistency_loss \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # run eval and render - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done + # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + # do + # CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + # --strategy.cap-max $CAP_MAX \ + # --data_dir $SCENE_DIR/$SCENE/ \ + # --result_dir $RESULT_DIR/$SCENE/ \ + # --ckpt $CKPT + # done done diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a6c8f65dc..fec87ab5b 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -130,6 +130,13 @@ class Config: depth_loss: bool = False # Weight for depth loss depth_lambda: float = 1e-2 + + # Enable normal consistency loss. (experimental) + normal_consistency_loss: bool = False + # Weight for normal consistency loss + normal_consistency_lambda: float = 0.05 + # Start applying normal consistency loss after this iteration + normal_consistency_start_iter: int = 7000 # Dump information to tensorboard every this steps tb_every: int = 100 @@ -249,6 +256,12 @@ def __init__( self.local_rank = local_rank self.world_size = world_size self.device = f"cuda:{local_rank}" + + self.render_mode = "RGB" + if cfg.depth_loss: + self.render_mode = "RGB+ED" + if cfg.normal_consistency_loss: + self.render_mode = "RGB+ED+N" # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -527,12 +540,9 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode=self.render_mode, ) - if renders.shape[-1] == 4: - colors, depths = renders[..., 0:3], renders[..., 3:4] - else: - colors, depths = renders, None + colors = renders[..., :3] if cfg.random_bkgd: bkgd = torch.rand(1, 3, device=device) @@ -553,6 +563,7 @@ def train(self): ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda if cfg.depth_loss: + depths = renders[..., -1:] # query depths from depth map points = torch.stack( [ @@ -571,6 +582,14 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + normalconsistencyloss = ( + 1 - (normals_rend * normals_surf).sum(dim=-1) + ).mean() + if step > cfg.normal_consistency_start_iter: + loss += normalconsistencyloss * cfg.normal_consistency_lambda # regularizations if cfg.opacity_reg > 0.0: @@ -614,6 +633,12 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.normal_consistency_loss: + self.writer.add_scalar( + "train/normalconsistencyloss", + normalconsistencyloss.item(), + step, + ) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) @@ -740,7 +765,7 @@ def eval(self, step: int, stage: str = "val"): torch.cuda.synchronize() tic = time.time() - colors, _, _ = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -748,16 +773,28 @@ def eval(self, step: int, stage: str = "val"): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, + render_mode=self.render_mode, ) # [1, H, W, 3] - colors = torch.clamp(colors, 0.0, 1.0) torch.cuda.synchronize() ellipse_time += time.time() - tic + + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + canvas_list = [pixels, colors] + if cfg.depth_loss: + depths = renders[..., -1:] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(depths) + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) if world_rank == 0: # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() imageio.imwrite( - f"{self.render_dir}/{stage}_{i:04d}.png", + f"{self.render_dir}/{stage}_step{step:04d}_{i:04d}.png", (canvas * 255).astype(np.uint8), ) @@ -816,7 +853,7 @@ def render_traj(self, step: int): canvas_all = [] for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): - renders, _, _ = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds[i : i + 1], Ks=K[None], width=width, @@ -824,8 +861,21 @@ def render_traj(self, step: int): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - render_mode="RGB+ED", + render_mode=self.render_mode, ) # [1, H, W, 4] + + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + canvas_list = [colors] + if cfg.depth_loss: + depths = renders[..., -1:] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(depths) + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] depths = renders[0, ..., 3:4] # [H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 5d4d772ba..65f726735 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -89,6 +89,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -252,6 +253,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index e80673af1..7c5207f50 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -292,6 +292,7 @@ fully_fused_projection_bwd_tensor( GSPLAT_CHECK_INPUT(conics); GSPLAT_CHECK_INPUT(v_means2d); GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); GSPLAT_CHECK_INPUT(v_conics); if (compensations.has_value()) { GSPLAT_CHECK_INPUT(compensations.value()); @@ -342,6 +343,7 @@ fully_fused_projection_bwd_tensor( : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index a4cf3ac0b..9a11bc1c8 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -37,6 +37,7 @@ __global__ void fully_fused_projection_fwd_kernel( int32_t *__restrict__ radii, // [C, N] T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] + T *__restrict__ normals, // [C, N, 3] T *__restrict__ conics, // [C, N, 3] T *__restrict__ compensations // [C, N] optional ) { @@ -187,6 +188,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -255,6 +257,7 @@ fully_fused_projection_fwd_tensor( radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr ); diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index d6637b59b..b2267c7f3 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -343,6 +343,7 @@ fully_fused_projection_packed_bwd_tensor( GSPLAT_CHECK_INPUT(conics); GSPLAT_CHECK_INPUT(v_means2d); GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); GSPLAT_CHECK_INPUT(v_conics); if (compensations.has_value()) { GSPLAT_CHECK_INPUT(compensations.value()); @@ -408,6 +409,7 @@ fully_fused_projection_packed_bwd_tensor( : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 38c13b54a..27bc4c9a3 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -44,6 +44,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( int32_t *__restrict__ radii, // [nnz] T *__restrict__ means2d, // [nnz, 2] T *__restrict__ depths, // [nnz] + T *__restrict__ normals, // [nnz, 3] T *__restrict__ conics, // [nnz, 3] T *__restrict__ compensations // [nnz] optional ) { @@ -246,6 +247,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -319,6 +321,7 @@ fully_fused_projection_packed_fwd_tensor( nullptr, nullptr, nullptr, + nullptr, nullptr ); block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); @@ -369,6 +372,7 @@ fully_fused_projection_packed_fwd_tensor( radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr ); @@ -383,6 +387,7 @@ fully_fused_projection_packed_fwd_tensor( radii, means2d, depths, + normals, conics, compensations ); diff --git a/gsplat/rendering.py b/gsplat/rendering.py index e987d8b40..040ad8f4f 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -557,6 +557,25 @@ def rasterization( absgrad=absgrad, ) + if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: + # normalize the accumulated depth to get the expected depth + render_colors[..., -1:] /= render_alphas.clamp(min=1e-10) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + render_colors[..., -1:], + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } + ) return render_colors, render_alphas, meta From a86ef37f67ba86ecbab2c193ff4b32c017cad911 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 11:52:28 -0700 Subject: [PATCH 53/66] ellipse --- examples/benchmarks/mcmc.sh | 21 ++++++++++-------- examples/datasets/traj.py | 2 +- examples/simple_trainer.py | 44 +++++++++++++++++++++++++++---------- 3 files changed, 45 insertions(+), 22 deletions(-) diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh index 4b0add753..d960ce76d 100644 --- a/examples/benchmarks/mcmc.sh +++ b/examples/benchmarks/mcmc.sh @@ -1,6 +1,7 @@ SCENE_DIR="data/360_v2" RESULT_DIR="results/benchmark_mcmc_1M" SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +RENDER_TRAJ_PATH="ellipse" CAP_MAX=1000000 @@ -15,20 +16,22 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ + --render_traj_path $RENDER_TRAJ_PATH \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # run eval and render - for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - do - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - --strategy.cap-max $CAP_MAX \ - --data_dir $SCENE_DIR/$SCENE/ \ - --result_dir $RESULT_DIR/$SCENE/ \ - --ckpt $CKPT - done + # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + # do + # CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + # --strategy.cap-max $CAP_MAX \ + # --render_traj_path $RENDER_TRAJ_PATH \ + # --data_dir $SCENE_DIR/$SCENE/ \ + # --result_dir $RESULT_DIR/$SCENE/ \ + # --ckpt $CKPT + # done done diff --git a/examples/datasets/traj.py b/examples/datasets/traj.py index 8d49aa711..8fcc981b2 100644 --- a/examples/datasets/traj.py +++ b/examples/datasets/traj.py @@ -90,7 +90,7 @@ def get_positions(theta): ind_up = np.argmax(np.abs(avg_up)) up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up]) - return np.stack([viewmatrix(p - center, up, p) for p in positions]) + return np.stack([viewmatrix(center - p, up, p) for p in positions]) def generate_ellipse_path_y( diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a6c8f65dc..53c6e9ea3 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -15,7 +15,7 @@ import viser import yaml from datasets.colmap import Dataset, Parser -from datasets.traj import generate_interpolated_path +from datasets.traj import generate_interpolated_path, generate_ellipse_path_z from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter @@ -38,6 +38,8 @@ class Config: ckpt: Optional[List[str]] = None # Name of compression strategy to use compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset data_dir: str = "data/360_v2/garden" @@ -63,7 +65,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) @@ -757,7 +759,7 @@ def eval(self, step: int, stage: str = "val"): # write images canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() imageio.imwrite( - f"{self.render_dir}/{stage}_{i:04d}.png", + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", (canvas * 255).astype(np.uint8), ) @@ -800,25 +802,43 @@ def render_traj(self, step: int): cfg = self.cfg device = self.device - camtoworlds = self.parser.camtoworlds[5:-5] - camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] - camtoworlds = np.concatenate( + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( [ - camtoworlds, - np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), ], axis=1, ) # [N, 4, 4] - camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) width, height = list(self.parser.imsize_dict.values())[0] canvas_all = [] - for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + renders, _, _ = self.rasterize_splats( - camtoworlds=camtoworlds[i : i + 1], - Ks=K[None], + camtoworlds=camtoworlds, + Ks=Ks, width=width, height=height, sh_degree=cfg.sh_degree, From f0b93f045debc09c2c36fe9305f025c1345b6b33 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 12:44:29 -0700 Subject: [PATCH 54/66] cleanup --- examples/benchmarks/basic.sh | 3 +++ examples/benchmarks/mcmc.sh | 20 ++++++++++---------- examples/simple_trainer.py | 2 +- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/benchmarks/basic.sh b/examples/benchmarks/basic.sh index 6a0986aa3..6b043c567 100644 --- a/examples/benchmarks/basic.sh +++ b/examples/benchmarks/basic.sh @@ -1,6 +1,7 @@ SCENE_DIR="data/360_v2" RESULT_DIR="results/benchmark" SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +RENDER_TRAJ_PATH="ellipse" for SCENE in $SCENE_LIST; do @@ -14,6 +15,7 @@ do # train without eval CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ @@ -21,6 +23,7 @@ do for CKPT in $RESULT_DIR/$SCENE/ckpts/*; do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ \ --ckpt $CKPT diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh index d960ce76d..23e40838d 100644 --- a/examples/benchmarks/mcmc.sh +++ b/examples/benchmarks/mcmc.sh @@ -16,22 +16,22 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir data/360_v2/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # run eval and render - # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - # do - # CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - # --strategy.cap-max $CAP_MAX \ - # --render_traj_path $RENDER_TRAJ_PATH \ - # --data_dir $SCENE_DIR/$SCENE/ \ - # --result_dir $RESULT_DIR/$SCENE/ \ - # --ckpt $CKPT - # done + for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + do + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + --strategy.cap-max $CAP_MAX \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --ckpt $CKPT + done done diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 53c6e9ea3..8eee54e4e 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -65,7 +65,7 @@ class Config: # Number of training steps max_steps: int = 30_000 # Steps to evaluate the model - eval_steps: List[int] = field(default_factory=lambda: [1_000, 7_000, 30_000]) + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) From a014c893daf2192e5633359a63ba70644729ad53 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 13:23:05 -0700 Subject: [PATCH 55/66] uncomment --- examples/simple_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 8eee54e4e..a5e09d4bc 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -923,7 +923,7 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) step = ckpts[0]["step"] runner.eval(step=step) - # runner.render_traj(step=step) + runner.render_traj(step=step) if cfg.compression is not None: runner.run_compression(step=step) else: From 35b21c4e7d6998d3ef531070e30f5838f30b9186 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 14:19:15 -0700 Subject: [PATCH 56/66] canvas list --- examples/simple_trainer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index a5e09d4bc..ad4b3d353 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -751,16 +751,19 @@ def eval(self, step: int, stage: str = "val"): near_plane=cfg.near_plane, far_plane=cfg.far_plane, ) # [1, H, W, 3] - colors = torch.clamp(colors, 0.0, 1.0) torch.cuda.synchronize() ellipse_time += time.time() - tic + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + if world_rank == 0: # write images - canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) imageio.imwrite( f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", - (canvas * 255).astype(np.uint8), + canvas, ) pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] @@ -846,15 +849,14 @@ def render_traj(self, step: int): far_plane=cfg.far_plane, render_mode="RGB+ED", ) # [1, H, W, 4] - colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] - depths = renders[0, ..., 3:4] # [H, W, 1] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] # write images - canvas = torch.cat( - [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1 - ) - canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) canvas_all.append(canvas) # save to video From f072f17b8f02de39b8d7fa2a7e91bf93450cb24f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 14:42:38 -0700 Subject: [PATCH 57/66] cleanup --- examples/benchmark_mcmc.sh | 33 -------- gsplat/rendering.py | 143 ---------------------------------- gsplat/rendering_inria.py | 152 +++++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 176 deletions(-) delete mode 100644 examples/benchmark_mcmc.sh create mode 100644 gsplat/rendering_inria.py diff --git a/examples/benchmark_mcmc.sh b/examples/benchmark_mcmc.sh deleted file mode 100644 index a960e1abc..000000000 --- a/examples/benchmark_mcmc.sh +++ /dev/null @@ -1,33 +0,0 @@ - - -SCENE_DIR="data/360_v2" -RESULTS_DIR="results/360_v2" -SCENE_LIST="garden bicycle stump treehill flowers bonsai counter kitchen room" -RENDER_TRAJ_PATH="ellipse" - - -for SCENE in $SCENE_LIST; -do - echo "Running $SCENE" - - if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then - DATA_FACTOR=2 - else - DATA_FACTOR=4 - fi - - CAP_MAX=3000000 - MAX_STEPS=30000 - EVAL_STEPS="1000 7000 15000 30000" - SAVE_STEPS="15000 30000" - - python simple_trainer_mcmc.py --eval_steps $EVAL_STEPS --save_steps $SAVE_STEPS --disable_viewer --data_factor $DATA_FACTOR \ - --init_type sfm \ - --cap_max $CAP_MAX \ - --max_steps $MAX_STEPS \ - --data_dir $SCENE_DIR/$SCENE/ \ - --render_traj_path $RENDER_TRAJ_PATH \ - --normal_consistency_loss \ - --result_dir $RESULTS_DIR/3dgs_normal/$SCENE/ - -done diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 0ed1b2151..3ef1b98a4 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1031,146 +1031,3 @@ def rasterization_inria_wrapper( render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) return render_colors, None, {} - - -def rasterization_2dgs_inria_wrapper( - means: Tensor, # [N, 3] - quats: Tensor, # [N, 4] - scales: Tensor, # [N, 3] - opacities: Tensor, # [N] - colors: Tensor, # [N, D] or [N, K, 3] - viewmats: Tensor, # [C, 4, 4] - Ks: Tensor, # [C, 3, 3] - width: int, - height: int, - near_plane: float = 0.01, - far_plane: float = 100.0, - eps2d: float = 0.3, - sh_degree: Optional[int] = None, - backgrounds: Optional[Tensor] = None, - **kwargs, -) -> Tuple[Tensor, Tensor, Dict]: - """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. - - Install the 2DGS rasterization backend from - https://github.com/hbb1/diff-surfel-rasterization - """ - from diff_surfel_rasterization import ( - GaussianRasterizationSettings, - GaussianRasterizer, - ) - - assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" - C = len(viewmats) - device = means.device - channels = colors.shape[-1] - - # rasterization from inria does not do normalization internally - quats = F.normalize(quats, dim=-1) # [N, 4] - scales = scales[:, :2] # [N, 2] - - render_colors = [] - for cid in range(C): - FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) - FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) - tanfovx = math.tan(FoVx * 0.5) - tanfovy = math.tan(FoVy * 0.5) - - world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = getProjectionMatrix( - znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device - ).transpose(0, 1) - full_proj_transform = ( - world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) - ).squeeze(0) - camera_center = world_view_transform.inverse()[3, :3] - - background = ( - backgrounds[cid] - if backgrounds is not None - else torch.zeros(3, device=device) - ) - - raster_settings = GaussianRasterizationSettings( - image_height=height, - image_width=width, - tanfovx=tanfovx, - tanfovy=tanfovy, - bg=background, - scale_modifier=1.0, - viewmatrix=world_view_transform, - projmatrix=full_proj_transform, - sh_degree=0 if sh_degree is None else sh_degree, - campos=camera_center, - prefiltered=False, - debug=False, - ) - - rasterizer = GaussianRasterizer(raster_settings=raster_settings) - - means2D = torch.zeros_like(means, requires_grad=True, device=device) - - render_colors_ = [] - for i in range(0, channels, 3): - _colors = colors[..., i : i + 3] - if _colors.shape[-1] < 3: - pad = torch.zeros( - _colors.shape[0], 3 - _colors.shape[-1], device=device - ) - _colors = torch.cat([_colors, pad], dim=-1) - _render_colors_, _, allmap = rasterizer( - means3D=means, - means2D=means2D, - shs=_colors if colors.dim() == 3 else None, - colors_precomp=_colors if colors.dim() == 2 else None, - opacities=opacities[:, None], - scales=scales, - rotations=quats, - cov3D_precomp=None, - ) - if _colors.shape[-1] < 3: - _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]] - render_colors_.append(_render_colors_) - render_colors_ = torch.cat(render_colors_, dim=-1) - - render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] - render_colors.append(render_colors_) - render_colors = torch.stack(render_colors, dim=0) - - # additional maps - allmap = allmap.permute(1, 2, 0).unsqueeze(0) # [1, H, W, C] - render_depth_expected = allmap[..., 0:1] - render_alphas = allmap[..., 1:2] - render_normal = allmap[..., 2:5] - render_depth_median = allmap[..., 5:6] - render_dist = allmap[..., 6:7] - - render_normal = render_normal @ (world_view_transform[:3, :3].T) - render_depth_expected = render_depth_expected / render_alphas - render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) - render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) - - # render_depth is either median or expected by setting depth_ratio to 1 or 0 - # for bounded scene, use median depth, i.e., depth_ratio = 1; - # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. - depth_ratio = 0 - render_depth = ( - render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median - ) - - normals_surf = depth_to_normal( - render_depth, - viewmats, - Ks, - near_plane=near_plane, - far_plane=far_plane, - ) - normals_surf = normals_surf * (render_alphas).detach() - - render_colors = torch.cat([render_colors, render_depth], dim=-1) - meta = { - "normals_rend": render_normal, - "normals_surf": normals_surf, - "render_distloss": render_dist, - } - return render_colors, render_alphas, meta diff --git a/gsplat/rendering_inria.py b/gsplat/rendering_inria.py new file mode 100644 index 000000000..4a2ab65c8 --- /dev/null +++ b/gsplat/rendering_inria.py @@ -0,0 +1,152 @@ +import math +from typing import Dict, Optional, Tuple + +import torch +from torch import Tensor +import torch.nn.functional as F + +from .util.normal_utils import depth_to_normal +from .util.camera_utils import getProjectionMatrix + + +def rasterization_2dgs_inria_wrapper( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + opacities: Tensor, # [N] + colors: Tensor, # [N, D] or [N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 100.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + backgrounds: Optional[Tensor] = None, + **kwargs, +) -> Tuple[Tensor, Tensor, Dict]: + """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. + + Install the 2DGS rasterization backend from + https://github.com/hbb1/diff-surfel-rasterization + """ + from diff_surfel_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) + + assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" + C = len(viewmats) + device = means.device + channels = colors.shape[-1] + + # rasterization from inria does not do normalization internally + quats = F.normalize(quats, dim=-1) # [N, 4] + scales = scales[:, :2] # [N, 2] + + render_colors = [] + for cid in range(C): + FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) + FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) + tanfovx = math.tan(FoVx * 0.5) + tanfovy = math.tan(FoVy * 0.5) + + world_view_transform = viewmats[cid].transpose(0, 1) + projection_matrix = getProjectionMatrix( + znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device + ).transpose(0, 1) + full_proj_transform = ( + world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) + ).squeeze(0) + camera_center = world_view_transform.inverse()[3, :3] + + background = ( + backgrounds[cid] + if backgrounds is not None + else torch.zeros(3, device=device) + ) + + raster_settings = GaussianRasterizationSettings( + image_height=height, + image_width=width, + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=background, + scale_modifier=1.0, + viewmatrix=world_view_transform, + projmatrix=full_proj_transform, + sh_degree=0 if sh_degree is None else sh_degree, + campos=camera_center, + prefiltered=False, + debug=False, + ) + + rasterizer = GaussianRasterizer(raster_settings=raster_settings) + + means2D = torch.zeros_like(means, requires_grad=True, device=device) + + render_colors_ = [] + for i in range(0, channels, 3): + _colors = colors[..., i : i + 3] + if _colors.shape[-1] < 3: + pad = torch.zeros( + _colors.shape[0], 3 - _colors.shape[-1], device=device + ) + _colors = torch.cat([_colors, pad], dim=-1) + _render_colors_, _, allmap = rasterizer( + means3D=means, + means2D=means2D, + shs=_colors if colors.dim() == 3 else None, + colors_precomp=_colors if colors.dim() == 2 else None, + opacities=opacities[:, None], + scales=scales, + rotations=quats, + cov3D_precomp=None, + ) + if _colors.shape[-1] < 3: + _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]] + render_colors_.append(_render_colors_) + render_colors_ = torch.cat(render_colors_, dim=-1) + + render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] + render_colors.append(render_colors_) + render_colors = torch.stack(render_colors, dim=0) + + # additional maps + allmap = allmap.permute(1, 2, 0).unsqueeze(0) # [1, H, W, C] + render_depth_expected = allmap[..., 0:1] + render_alphas = allmap[..., 1:2] + render_normal = allmap[..., 2:5] + render_depth_median = allmap[..., 5:6] + render_dist = allmap[..., 6:7] + + render_normal = render_normal @ (world_view_transform[:3, :3].T) + render_depth_expected = render_depth_expected / render_alphas + render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) + render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) + + # render_depth is either median or expected by setting depth_ratio to 1 or 0 + # for bounded scene, use median depth, i.e., depth_ratio = 1; + # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. + depth_ratio = 0 + render_depth = ( + render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median + ) + + normals_surf = depth_to_normal( + render_depth, + viewmats, + Ks, + near_plane=near_plane, + far_plane=far_plane, + ) + normals_surf = normals_surf * (render_alphas).detach() + + render_colors = torch.cat([render_colors, render_depth], dim=-1) + meta = { + "normals_rend": render_normal, + "normals_surf": normals_surf, + "render_distloss": render_dist, + } + return render_colors, render_alphas, meta From 1fa189cd71ef3c55de94bd52ebf723e8e9d9722f Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 15:07:07 -0700 Subject: [PATCH 58/66] fix tests --- tests/test_basic.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 49c3b2b0d..2eb63a56a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -166,7 +166,8 @@ def test_proj(test_data, ortho: bool): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") @pytest.mark.parametrize("calc_compensations", [False, True]) -def test_projection(test_data, calc_compensations: bool): +@pytest.mark.parametrize("ortho", [True, False]) +def test_projection(test_data, calc_compensations: bool, ortho: bool): from gsplat.cuda._torch_impl import _fully_fused_projection, _quat_to_rotmat from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci @@ -196,6 +197,7 @@ def test_projection(test_data, calc_compensations: bool): width, height, calc_compensations=calc_compensations, + ortho=ortho, ) ( _radii, @@ -263,7 +265,7 @@ def test_projection(test_data, calc_compensations: bool): @pytest.mark.parametrize("calc_compensations", [False, True]) @pytest.mark.parametrize("ortho", [True, False]) def test_fully_fused_projection_packed( - test_data, sparse_grad: bool, calc_compensations: bool + test_data, sparse_grad: bool, calc_compensations: bool, ortho: bool ): from gsplat.cuda._wrapper import fully_fused_projection From b9ef876fb8d5c6a35e0e5cc3375896564e2dac80 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 15:12:39 -0700 Subject: [PATCH 59/66] remove 2dgs inria --- gsplat/rendering_inria.py | 152 -------------------------------------- 1 file changed, 152 deletions(-) delete mode 100644 gsplat/rendering_inria.py diff --git a/gsplat/rendering_inria.py b/gsplat/rendering_inria.py deleted file mode 100644 index 4a2ab65c8..000000000 --- a/gsplat/rendering_inria.py +++ /dev/null @@ -1,152 +0,0 @@ -import math -from typing import Dict, Optional, Tuple - -import torch -from torch import Tensor -import torch.nn.functional as F - -from .util.normal_utils import depth_to_normal -from .util.camera_utils import getProjectionMatrix - - -def rasterization_2dgs_inria_wrapper( - means: Tensor, # [N, 3] - quats: Tensor, # [N, 4] - scales: Tensor, # [N, 3] - opacities: Tensor, # [N] - colors: Tensor, # [N, D] or [N, K, 3] - viewmats: Tensor, # [C, 4, 4] - Ks: Tensor, # [C, 3, 3] - width: int, - height: int, - near_plane: float = 0.01, - far_plane: float = 100.0, - eps2d: float = 0.3, - sh_degree: Optional[int] = None, - backgrounds: Optional[Tensor] = None, - **kwargs, -) -> Tuple[Tensor, Tensor, Dict]: - """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. - - Install the 2DGS rasterization backend from - https://github.com/hbb1/diff-surfel-rasterization - """ - from diff_surfel_rasterization import ( - GaussianRasterizationSettings, - GaussianRasterizer, - ) - - assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" - C = len(viewmats) - device = means.device - channels = colors.shape[-1] - - # rasterization from inria does not do normalization internally - quats = F.normalize(quats, dim=-1) # [N, 4] - scales = scales[:, :2] # [N, 2] - - render_colors = [] - for cid in range(C): - FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) - FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) - tanfovx = math.tan(FoVx * 0.5) - tanfovy = math.tan(FoVy * 0.5) - - world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = getProjectionMatrix( - znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device - ).transpose(0, 1) - full_proj_transform = ( - world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) - ).squeeze(0) - camera_center = world_view_transform.inverse()[3, :3] - - background = ( - backgrounds[cid] - if backgrounds is not None - else torch.zeros(3, device=device) - ) - - raster_settings = GaussianRasterizationSettings( - image_height=height, - image_width=width, - tanfovx=tanfovx, - tanfovy=tanfovy, - bg=background, - scale_modifier=1.0, - viewmatrix=world_view_transform, - projmatrix=full_proj_transform, - sh_degree=0 if sh_degree is None else sh_degree, - campos=camera_center, - prefiltered=False, - debug=False, - ) - - rasterizer = GaussianRasterizer(raster_settings=raster_settings) - - means2D = torch.zeros_like(means, requires_grad=True, device=device) - - render_colors_ = [] - for i in range(0, channels, 3): - _colors = colors[..., i : i + 3] - if _colors.shape[-1] < 3: - pad = torch.zeros( - _colors.shape[0], 3 - _colors.shape[-1], device=device - ) - _colors = torch.cat([_colors, pad], dim=-1) - _render_colors_, _, allmap = rasterizer( - means3D=means, - means2D=means2D, - shs=_colors if colors.dim() == 3 else None, - colors_precomp=_colors if colors.dim() == 2 else None, - opacities=opacities[:, None], - scales=scales, - rotations=quats, - cov3D_precomp=None, - ) - if _colors.shape[-1] < 3: - _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]] - render_colors_.append(_render_colors_) - render_colors_ = torch.cat(render_colors_, dim=-1) - - render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] - render_colors.append(render_colors_) - render_colors = torch.stack(render_colors, dim=0) - - # additional maps - allmap = allmap.permute(1, 2, 0).unsqueeze(0) # [1, H, W, C] - render_depth_expected = allmap[..., 0:1] - render_alphas = allmap[..., 1:2] - render_normal = allmap[..., 2:5] - render_depth_median = allmap[..., 5:6] - render_dist = allmap[..., 6:7] - - render_normal = render_normal @ (world_view_transform[:3, :3].T) - render_depth_expected = render_depth_expected / render_alphas - render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) - render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) - - # render_depth is either median or expected by setting depth_ratio to 1 or 0 - # for bounded scene, use median depth, i.e., depth_ratio = 1; - # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. - depth_ratio = 0 - render_depth = ( - render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median - ) - - normals_surf = depth_to_normal( - render_depth, - viewmats, - Ks, - near_plane=near_plane, - far_plane=far_plane, - ) - normals_surf = normals_surf * (render_alphas).detach() - - render_colors = torch.cat([render_colors, render_depth], dim=-1) - meta = { - "normals_rend": render_normal, - "normals_surf": normals_surf, - "render_distloss": render_dist, - } - return render_colors, render_alphas, meta From c08a23ea8738bee095cc75e65edfb786aaafd9df Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Sat, 31 Aug 2024 20:26:55 -0700 Subject: [PATCH 60/66] script --- examples/benchmarks/mcmc.sh | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh index a565819f3..4ae90d00c 100644 --- a/examples/benchmarks/mcmc.sh +++ b/examples/benchmarks/mcmc.sh @@ -4,8 +4,6 @@ SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers RENDER_TRAJ_PATH="ellipse" CAP_MAX=1000000 -EVAL_STEPS="7000 30000" -SAVE_STEPS="7000 30000" for SCENE in $SCENE_LIST; do @@ -18,7 +16,7 @@ do echo "Running $SCENE" # train without eval - CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps $EVAL_STEPS --save_steps $SAVE_STEPS --disable_viewer --data_factor $DATA_FACTOR \ + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ --normal_consistency_loss \ --render_traj_path $RENDER_TRAJ_PATH \ @@ -26,15 +24,16 @@ do --result_dir $RESULT_DIR/$SCENE/ # run eval and render - # for CKPT in $RESULT_DIR/$SCENE/ckpts/*; - # do - # CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ - # --strategy.cap-max $CAP_MAX \ - # --render_traj_path $RENDER_TRAJ_PATH \ - # --data_dir $SCENE_DIR/$SCENE/ \ - # --result_dir $RESULT_DIR/$SCENE/ \ - # --ckpt $CKPT - # done + for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + do + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + --strategy.cap-max $CAP_MAX \ + --normal_consistency_loss \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --ckpt $CKPT + done done From bfd78bc3125832f92bc140beda925d47ce81f76c Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 13 Sep 2024 13:11:50 -0700 Subject: [PATCH 61/66] fix merge --- examples/benchmarks/mcmc.sh | 4 +--- examples/benchmarks/normal/mcmc_normal.sh | 22 ++++++++++++++++++++++ gsplat/cuda/_torch_impl.py | 17 +++++++++++++---- 3 files changed, 36 insertions(+), 7 deletions(-) create mode 100644 examples/benchmarks/normal/mcmc_normal.sh diff --git a/examples/benchmarks/mcmc.sh b/examples/benchmarks/mcmc.sh index 4ae90d00c..0eaa5c8bb 100644 --- a/examples/benchmarks/mcmc.sh +++ b/examples/benchmarks/mcmc.sh @@ -18,9 +18,8 @@ do # train without eval CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ - --normal_consistency_loss \ --render_traj_path $RENDER_TRAJ_PATH \ - --data_dir data/360_v2/$SCENE/ \ + --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ # run eval and render @@ -28,7 +27,6 @@ do do CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ --strategy.cap-max $CAP_MAX \ - --normal_consistency_loss \ --render_traj_path $RENDER_TRAJ_PATH \ --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ \ diff --git a/examples/benchmarks/normal/mcmc_normal.sh b/examples/benchmarks/normal/mcmc_normal.sh new file mode 100644 index 000000000..1de152f34 --- /dev/null +++ b/examples/benchmarks/normal/mcmc_normal.sh @@ -0,0 +1,22 @@ +SCENE_DIR="data/360_v2" +SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers + +RESULT_DIR="results/benchmark_normal" +RENDER_TRAJ_PATH="ellipse" + +for SCENE in $SCENE_LIST; +do + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + echo "Running $SCENE" + + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + --normal_consistency_loss \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index a16368465..6f5c4dce2 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -7,6 +7,7 @@ def _quat_to_rotmat(quats: Tensor) -> Tensor: + """Convert quaternion to rotation matrix.""" quats = F.normalize(quats, p=2, dim=-1) w, x, y, z = torch.unbind(quats, dim=-1) R = torch.stack( @@ -23,8 +24,17 @@ def _quat_to_rotmat(quats: Tensor) -> Tensor: ], dim=-1, ) - R = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3) - return R + return R.reshape(quats.shape[:-1] + (3, 3)) + + +def _quat_scale_to_matrix( + quats: Tensor, # [N, 4], + scales: Tensor, # [N, 3], +) -> Tensor: + """Convert quaternion and scale to a 3x3 matrix (R * S).""" + R = _quat_to_rotmat(quats) # (..., 3, 3) + M = R * scales[..., None, :] # (..., 3, 3) + return M def _quat_scale_to_covar_preci( @@ -35,8 +45,7 @@ def _quat_scale_to_covar_preci( triu: bool = False, ) -> Tuple[Optional[Tensor], Optional[Tensor]]: """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`.""" - R = _quat_to_rotmat(quats) - # R.register_hook(lambda grad: print("grad R", grad)) + R = _quat_to_rotmat(quats) # (..., 3, 3) if compute_covar: M = R * scales[..., None, :] # (..., 3, 3) From 0a1dc09752c68bc6a135537c4bcf331547f67c90 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 13 Sep 2024 13:52:21 -0700 Subject: [PATCH 62/66] fix utils --- examples/simple_trainer.py | 3 -- gsplat/cuda/_wrapper.py | 4 +-- gsplat/rendering.py | 10 ++---- gsplat/util/__init__.py | 0 gsplat/util/camera_utils.py | 25 ------------- gsplat/util/normal_utils.py | 70 ------------------------------------- tests/test_basic.py | 9 +++-- 7 files changed, 10 insertions(+), 111 deletions(-) delete mode 100644 gsplat/util/__init__.py delete mode 100644 gsplat/util/camera_utils.py delete mode 100644 gsplat/util/normal_utils.py diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 439795465..6d7910d20 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -860,9 +860,6 @@ def eval(self, step: int, stage: str = "val"): canvas_list.extend([normals_rend * 0.5 + 0.5]) canvas_list.extend([normals_surf * 0.5 + 0.5]) - colors = torch.clamp(colors, 0.0, 1.0) - canvas_list = [pixels, colors] - if world_rank == 0: # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 8a922e241..c2500d6d5 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -188,7 +188,7 @@ def fully_fused_projection( sparse_grad: bool = False, calc_compensations: bool = False, ortho: bool = False, -) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """Projects Gaussians to 2D. This function fuse the process of computing covariances @@ -754,7 +754,7 @@ def forward( radius_clip: float, calc_compensations: bool, ortho: bool, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: # "covars" and {"quats", "scales"} are mutually exclusive radii, means2d, depths, normals, conics, compensations = _make_lazy_cuda_func( "fully_fused_projection_fwd" diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 692a5dada..e8b04dead 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -17,8 +17,6 @@ rasterize_to_pixels_2dgs, spherical_harmonics, ) -from .util.normal_utils import depth_to_normal -from .util.camera_utils import getProjectionMatrix from .distributed import ( all_gather_int32, all_gather_tensor_list, @@ -587,10 +585,8 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso normals_rend = render_colors[..., -4:-1] normals_surf = depth_to_normal( render_colors[..., -1:], - viewmats, + camtoworlds, Ks, - near_plane=near_plane, - far_plane=far_plane, ) normals_surf = normals_surf * (render_alphas).detach() meta.update( @@ -828,10 +824,8 @@ def _rasterization( normals_rend = render_colors[..., -4:-1] normals_surf = depth_to_normal( render_colors[..., -1:], - viewmats, + camtoworlds, Ks, - near_plane=near_plane, - far_plane=far_plane, ) normals_surf = normals_surf * (render_alphas).detach() meta.update( diff --git a/gsplat/util/__init__.py b/gsplat/util/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/gsplat/util/camera_utils.py b/gsplat/util/camera_utils.py deleted file mode 100644 index 924e87115..000000000 --- a/gsplat/util/camera_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import math -import torch - - -def getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): - tanHalfFovY = math.tan((fovY / 2)) - tanHalfFovX = math.tan((fovX / 2)) - - top = tanHalfFovY * znear - bottom = -top - right = tanHalfFovX * znear - left = -right - - P = torch.zeros(4, 4, device=device) - - z_sign = 1.0 - - P[0, 0] = 2.0 * znear / (right - left) - P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) - P[3, 2] = z_sign - P[2, 2] = z_sign * zfar / (zfar - znear) - P[2, 3] = -(zfar * znear) / (zfar - znear) - return P diff --git a/gsplat/util/normal_utils.py b/gsplat/util/normal_utils.py deleted file mode 100644 index 9ca55c8c9..000000000 --- a/gsplat/util/normal_utils.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch -import torch.nn.functional as F -import math -from torch import Tensor - -from .camera_utils import getProjectionMatrix - - -def _depths_to_points(depthmap, world_view_transform, full_proj_transform): - c2w = (world_view_transform.T).inverse() - H, W = depthmap.shape[:2] - ndc2pix = ( - torch.tensor([[W / 2, 0, 0, (W) / 2], [0, H / 2, 0, (H) / 2], [0, 0, 0, 1]]) - .float() - .cuda() - .T - ) - projection_matrix = c2w.T @ full_proj_transform - intrins = (projection_matrix @ ndc2pix)[:3, :3].T - - grid_x, grid_y = torch.meshgrid( - torch.arange(W, device="cuda").float(), - torch.arange(H, device="cuda").float(), - indexing="xy", - ) - points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( - -1, 3 - ) - rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T - rays_o = c2w[:3, 3] - points = depthmap.reshape(-1, 1) * rays_d + rays_o - return points - - -def _depth_to_normal(depth, world_view_transform, full_proj_transform): - points = _depths_to_points( - depth, world_view_transform, full_proj_transform - ).reshape(*depth.shape[:2], 3) - output = torch.zeros_like(points) - dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) - dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) - normal_map = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1) - output[1:-1, 1:-1, :] = normal_map - return output - - -def depth_to_normal( - depths: Tensor, # [C, H, W, 1] - viewmats: Tensor, # [C, 4, 4] - Ks: Tensor, # [C, 3, 3] - near_plane: float = 0.01, - far_plane: float = 1e10, -) -> Tensor: - height, width = depths.shape[1:3] - - normals = [] - for cid, depth in enumerate(depths): - FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) - FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) - world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = getProjectionMatrix( - znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device - ).transpose(0, 1) - full_proj_transform = ( - world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) - ).squeeze(0) - normal = _depth_to_normal(depth, world_view_transform, full_proj_transform) - normals.append(normal) - normals = torch.stack(normals, dim=0) - return normals diff --git a/tests/test_basic.py b/tests/test_basic.py index a66828a8b..e0cd904b5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -431,6 +431,7 @@ def test_rasterize_to_pixels(test_data, channels: int): fully_fused_projection, isect_offset_encode, isect_tiles, + quat_scale_to_covar_preci, rasterize_to_pixels, ) @@ -448,9 +449,11 @@ def test_rasterize_to_pixels(test_data, channels: int): colors = torch.randn(C, len(means), channels, device=device) backgrounds = torch.rand((C, colors.shape[-1]), device=device) + covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) + # Project Gaussians to 2D radii, means2d, depths, normals, conics, compensations = fully_fused_projection( - means, None, quats, scales, viewmats, Ks, width, height + means, covars, None, None, viewmats, Ks, width, height ) opacities = opacities.repeat(C, 1) @@ -504,7 +507,7 @@ def test_rasterize_to_pixels(test_data, channels: int): v_means2d, v_conics, v_colors, v_opacities, v_backgrounds = torch.autograd.grad( (render_colors * v_render_colors).sum() + (render_alphas * v_render_alphas).sum(), - (means2d, conics, colors, opacities, backgrounds), + (means2d, normals, conics, colors, opacities, backgrounds), ) ( _v_means2d, @@ -515,7 +518,7 @@ def test_rasterize_to_pixels(test_data, channels: int): ) = torch.autograd.grad( (_render_colors * v_render_colors).sum() + (_render_alphas * v_render_alphas).sum(), - (means2d, conics, colors, opacities, backgrounds), + (means2d, normals, conics, colors, opacities, backgrounds), ) torch.testing.assert_close(v_means2d, _v_means2d, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v_conics, _v_conics, rtol=1e-3, atol=1e-3) From 5b5a7c32e630a8a9a7a4c394336b688163867d46 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 13 Sep 2024 13:59:43 -0700 Subject: [PATCH 63/66] reduce diff test_basic --- tests/test_basic.py | 177 ++++++++++++++++++++++++++++++-------------- 1 file changed, 121 insertions(+), 56 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index e0cd904b5..05931a3c8 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -165,9 +165,10 @@ def test_proj(test_data, ortho: bool): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +@pytest.mark.parametrize("fused", [True]) @pytest.mark.parametrize("calc_compensations", [False, True]) @pytest.mark.parametrize("ortho", [True, False]) -def test_projection(test_data, calc_compensations: bool, ortho: bool): +def test_projection(test_data, fused: bool, calc_compensations: bool, ortho: bool): from gsplat.cuda._torch_impl import _fully_fused_projection, _quat_to_rotmat from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci @@ -187,18 +188,33 @@ def test_projection(test_data, calc_compensations: bool, ortho: bool): means.requires_grad = True # forward - radii, means2d, depths, normals, conics, compensations = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - calc_compensations=calc_compensations, - ortho=ortho, - ) + if fused: + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + calc_compensations=calc_compensations, + ortho=ortho, + ) + else: + covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( + means, + covars, + None, + None, + viewmats, + Ks, + width, + height, + calc_compensations=calc_compensations, + ortho=ortho, + ) ( _radii, _means2d, @@ -261,13 +277,14 @@ def test_projection(test_data, calc_compensations: bool, ortho: bool): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +@pytest.mark.parametrize("fused", [True]) @pytest.mark.parametrize("sparse_grad", [False, True]) @pytest.mark.parametrize("calc_compensations", [False, True]) @pytest.mark.parametrize("ortho", [True, False]) def test_fully_fused_projection_packed( - test_data, sparse_grad: bool, calc_compensations: bool, ortho: bool + test_data, fused: bool, sparse_grad: bool, calc_compensations: bool, ortho: bool ): - from gsplat.cuda._wrapper import fully_fused_projection + from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci torch.manual_seed(42) @@ -285,47 +302,95 @@ def test_fully_fused_projection_packed( means.requires_grad = True # forward - ( - camera_ids, - gaussian_ids, - radii, - means2d, - depths, - normals, - conics, - compensations, - ) = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - packed=True, - sparse_grad=sparse_grad, - calc_compensations=calc_compensations, - ) - ( - _radii, - _means2d, - _depths, - _normals, - _conics, - _compensations, - ) = fully_fused_projection( - means, - None, - quats, - scales, - viewmats, - Ks, - width, - height, - packed=False, - calc_compensations=calc_compensations, - ) + if fused: + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=True, + sparse_grad=sparse_grad, + calc_compensations=calc_compensations, + ortho=ortho, + ) + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = fully_fused_projection( + means, + None, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=False, + calc_compensations=calc_compensations, + ortho=ortho, + ) + else: + covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) = fully_fused_projection( + means, + covars, + None, + None, + viewmats, + Ks, + width, + height, + packed=True, + sparse_grad=sparse_grad, + calc_compensations=calc_compensations, + ortho=ortho, + ) + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = fully_fused_projection( + means, + covars, + None, + None, + viewmats, + Ks, + width, + height, + packed=False, + calc_compensations=calc_compensations, + ortho=ortho, + ) # recover packed tensors to full matrices for testing __radii = torch.sparse_coo_tensor( From 9c4186a4201c3ac12a507b86452ce430a2942775 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 13 Sep 2024 14:04:33 -0700 Subject: [PATCH 64/66] tests not passing --- tests/test_basic.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 05931a3c8..8d4eb7290 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -496,7 +496,6 @@ def test_rasterize_to_pixels(test_data, channels: int): fully_fused_projection, isect_offset_encode, isect_tiles, - quat_scale_to_covar_preci, rasterize_to_pixels, ) @@ -514,11 +513,9 @@ def test_rasterize_to_pixels(test_data, channels: int): colors = torch.randn(C, len(means), channels, device=device) backgrounds = torch.rand((C, colors.shape[-1]), device=device) - covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) - # Project Gaussians to 2D radii, means2d, depths, normals, conics, compensations = fully_fused_projection( - means, covars, None, None, viewmats, Ks, width, height + means, None, quats, scales, viewmats, Ks, width, height ) opacities = opacities.repeat(C, 1) @@ -572,7 +569,7 @@ def test_rasterize_to_pixels(test_data, channels: int): v_means2d, v_conics, v_colors, v_opacities, v_backgrounds = torch.autograd.grad( (render_colors * v_render_colors).sum() + (render_alphas * v_render_alphas).sum(), - (means2d, normals, conics, colors, opacities, backgrounds), + (means2d, conics, colors, opacities, backgrounds), ) ( _v_means2d, @@ -583,7 +580,7 @@ def test_rasterize_to_pixels(test_data, channels: int): ) = torch.autograd.grad( (_render_colors * v_render_colors).sum() + (_render_alphas * v_render_alphas).sum(), - (means2d, normals, conics, colors, opacities, backgrounds), + (means2d, conics, colors, opacities, backgrounds), ) torch.testing.assert_close(v_means2d, _v_means2d, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v_conics, _v_conics, rtol=1e-3, atol=1e-3) From 0c5e3ed7d37287a3307807dc8fc99062e866f0e3 Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Fri, 13 Sep 2024 14:12:35 -0700 Subject: [PATCH 65/66] all tests passed --- gsplat/rendering.py | 6 +++--- tests/test_basic.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index e8b04dead..acee34ef7 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -585,7 +585,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso normals_rend = render_colors[..., -4:-1] normals_surf = depth_to_normal( render_colors[..., -1:], - camtoworlds, + torch.inverse(viewmats), Ks, ) normals_surf = normals_surf * (render_alphas).detach() @@ -824,7 +824,7 @@ def _rasterization( normals_rend = render_colors[..., -4:-1] normals_surf = depth_to_normal( render_colors[..., -1:], - camtoworlds, + torch.inverse(viewmats), Ks, ) normals_surf = normals_surf * (render_alphas).detach() @@ -1489,7 +1489,7 @@ def rasterization_2dgs_inria_wrapper( render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - normals_surf = depth_to_normal(render_depth, viewmats, Ks) + normals_surf = depth_to_normal(render_depth, torch.inverse(viewmats), Ks) normals_surf = normals_surf * (render_alphas).detach() render_colors = torch.cat([render_colors, render_depth], dim=-1) diff --git a/tests/test_basic.py b/tests/test_basic.py index 8d4eb7290..fb04fdcd4 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -169,7 +169,7 @@ def test_proj(test_data, ortho: bool): @pytest.mark.parametrize("calc_compensations", [False, True]) @pytest.mark.parametrize("ortho", [True, False]) def test_projection(test_data, fused: bool, calc_compensations: bool, ortho: bool): - from gsplat.cuda._torch_impl import _fully_fused_projection, _quat_to_rotmat + from gsplat.cuda._torch_impl import _fully_fused_projection from gsplat.cuda._wrapper import fully_fused_projection, quat_scale_to_covar_preci torch.manual_seed(42) From 38f253250254e5b37ba813bdebdadc21bd9918ce Mon Sep 17 00:00:00 2001 From: Jeffrey Hu Date: Tue, 24 Sep 2024 11:39:39 -0700 Subject: [PATCH 66/66] summarize stats --- examples/benchmarks/compression/mcmc.sh | 2 +- examples/benchmarks/compression/mcmc_tt.sh | 2 +- examples/benchmarks/normal/2dgs_dtu.sh | 17 +++++++++++ examples/benchmarks/normal/mcmc_dtu.sh | 28 +++++++++++++++++++ examples/benchmarks/normal/mcmc_normal.sh | 5 +++- .../{compression => }/summarize_stats.py | 20 +++++++------ 6 files changed, 62 insertions(+), 12 deletions(-) create mode 100644 examples/benchmarks/normal/2dgs_dtu.sh create mode 100644 examples/benchmarks/normal/mcmc_dtu.sh rename examples/benchmarks/{compression => }/summarize_stats.py (67%) diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh index 4c7165f3d..e55ef6aab 100644 --- a/examples/benchmarks/compression/mcmc.sh +++ b/examples/benchmarks/compression/mcmc.sh @@ -49,7 +49,7 @@ done if command -v zip &> /dev/null then echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR + python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress else echo "zip command not found, skipping zipping" fi \ No newline at end of file diff --git a/examples/benchmarks/compression/mcmc_tt.sh b/examples/benchmarks/compression/mcmc_tt.sh index 054920929..e637ae434 100644 --- a/examples/benchmarks/compression/mcmc_tt.sh +++ b/examples/benchmarks/compression/mcmc_tt.sh @@ -42,7 +42,7 @@ done if command -v zip &> /dev/null then echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST + python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress else echo "zip command not found, skipping zipping" fi \ No newline at end of file diff --git a/examples/benchmarks/normal/2dgs_dtu.sh b/examples/benchmarks/normal/2dgs_dtu.sh new file mode 100644 index 000000000..0b8adc296 --- /dev/null +++ b/examples/benchmarks/normal/2dgs_dtu.sh @@ -0,0 +1,17 @@ +SCENE_DIR="data/DTU" +SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122" + +RESULT_DIR="results/benchmark_dtu_2dgs" + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train and eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor 1 \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/normal/mcmc_dtu.sh b/examples/benchmarks/normal/mcmc_dtu.sh new file mode 100644 index 000000000..140412b2e --- /dev/null +++ b/examples/benchmarks/normal/mcmc_dtu.sh @@ -0,0 +1,28 @@ +SCENE_DIR="data/DTU" +SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122" +RENDER_TRAJ_PATH="ellipse" + +RESULT_DIR="results/benchmark_dtu_mcmc_0.25M_normal" +CAP_MAX=250000 + +# RESULT_DIR="results/benchmark_dtu_mcmc_0.5M_normal" +# CAP_MAX=500000 + +# RESULT_DIR="results/benchmark_dtu_mcmc_1M_normal" +# CAP_MAX=1000000 + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train and eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \ + --strategy.cap-max $CAP_MAX \ + --normal_consistency_loss \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/normal/mcmc_normal.sh b/examples/benchmarks/normal/mcmc_normal.sh index 1de152f34..3127099d6 100644 --- a/examples/benchmarks/normal/mcmc_normal.sh +++ b/examples/benchmarks/normal/mcmc_normal.sh @@ -1,5 +1,5 @@ SCENE_DIR="data/360_v2" -SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers" RESULT_DIR="results/benchmark_normal" RENDER_TRAJ_PATH="ellipse" @@ -20,3 +20,6 @@ do --data_dir $SCENE_DIR/$SCENE/ \ --result_dir $RESULT_DIR/$SCENE/ done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/summarize_stats.py similarity index 67% rename from examples/benchmarks/compression/summarize_stats.py rename to examples/benchmarks/summarize_stats.py index d11dbed6f..e76ebc90f 100644 --- a/examples/benchmarks/compression/summarize_stats.py +++ b/examples/benchmarks/summarize_stats.py @@ -8,11 +8,8 @@ import tyro -def main(results_dir: str, scenes: List[str]): - print("scenes:", scenes) - stage = "compress" - - summary = defaultdict(list) +def main(results_dir: str, scenes: List[str], stage: str = "val"): + stats_all = defaultdict(list) for scene in scenes: scene_dir = os.path.join(results_dir, scene) @@ -25,15 +22,20 @@ def main(results_dir: str, scenes: List[str]): f"stat -c%s {zip_path}", shell=True, capture_output=True ) size = int(out.stdout) - summary["size"].append(size) + stats_all["size"].append(size) with open(os.path.join(scene_dir, f"stats/{stage}_step29999.json"), "r") as f: stats = json.load(f) for k, v in stats.items(): - summary[k].append(v) + stats_all[k].append(v) + + summary = {"scenes": scenes} + for k, v in stats_all.items(): + summary[k] = np.mean(v) + print(summary) - for k, v in summary.items(): - print(k, np.mean(v)) + with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f: + json.dump(summary, f, indent=2) if __name__ == "__main__":