diff --git a/examples/benchmarks/compression/mcmc.sh b/examples/benchmarks/compression/mcmc.sh index 4c7165f3d..e55ef6aab 100644 --- a/examples/benchmarks/compression/mcmc.sh +++ b/examples/benchmarks/compression/mcmc.sh @@ -49,7 +49,7 @@ done if command -v zip &> /dev/null then echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR + python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress else echo "zip command not found, skipping zipping" fi \ No newline at end of file diff --git a/examples/benchmarks/compression/mcmc_tt.sh b/examples/benchmarks/compression/mcmc_tt.sh index 054920929..e637ae434 100644 --- a/examples/benchmarks/compression/mcmc_tt.sh +++ b/examples/benchmarks/compression/mcmc_tt.sh @@ -42,7 +42,7 @@ done if command -v zip &> /dev/null then echo "Zipping results" - python benchmarks/compression/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST + python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage compress else echo "zip command not found, skipping zipping" fi \ No newline at end of file diff --git a/examples/benchmarks/normal/2dgs_dtu.sh b/examples/benchmarks/normal/2dgs_dtu.sh new file mode 100644 index 000000000..0b8adc296 --- /dev/null +++ b/examples/benchmarks/normal/2dgs_dtu.sh @@ -0,0 +1,17 @@ +SCENE_DIR="data/DTU" +SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122" + +RESULT_DIR="results/benchmark_dtu_2dgs" + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train and eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor 1 \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/normal/mcmc_dtu.sh b/examples/benchmarks/normal/mcmc_dtu.sh new file mode 100644 index 000000000..140412b2e --- /dev/null +++ b/examples/benchmarks/normal/mcmc_dtu.sh @@ -0,0 +1,28 @@ +SCENE_DIR="data/DTU" +SCENE_LIST="scan24 scan37 scan40 scan55 scan63 scan65 scan69 scan83 scan97 scan105 scan106 scan110 scan114 scan118 scan122" +RENDER_TRAJ_PATH="ellipse" + +RESULT_DIR="results/benchmark_dtu_mcmc_0.25M_normal" +CAP_MAX=250000 + +# RESULT_DIR="results/benchmark_dtu_mcmc_0.5M_normal" +# CAP_MAX=500000 + +# RESULT_DIR="results/benchmark_dtu_mcmc_1M_normal" +# CAP_MAX=1000000 + +for SCENE in $SCENE_LIST; +do + echo "Running $SCENE" + + # train and eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor 1 \ + --strategy.cap-max $CAP_MAX \ + --normal_consistency_loss \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/normal/mcmc_normal.sh b/examples/benchmarks/normal/mcmc_normal.sh new file mode 100644 index 000000000..3127099d6 --- /dev/null +++ b/examples/benchmarks/normal/mcmc_normal.sh @@ -0,0 +1,25 @@ +SCENE_DIR="data/360_v2" +SCENE_LIST="garden bicycle stump bonsai counter kitchen room treehill flowers" + +RESULT_DIR="results/benchmark_normal" +RENDER_TRAJ_PATH="ellipse" + +for SCENE in $SCENE_LIST; +do + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + echo "Running $SCENE" + + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py mcmc --disable_viewer --data_factor $DATA_FACTOR \ + --normal_consistency_loss \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir $SCENE_DIR/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ +done + +echo "Summarizing results" +python benchmarks/summarize_stats.py --results_dir $RESULT_DIR --scenes $SCENE_LIST --stage val diff --git a/examples/benchmarks/compression/summarize_stats.py b/examples/benchmarks/summarize_stats.py similarity index 67% rename from examples/benchmarks/compression/summarize_stats.py rename to examples/benchmarks/summarize_stats.py index d11dbed6f..e76ebc90f 100644 --- a/examples/benchmarks/compression/summarize_stats.py +++ b/examples/benchmarks/summarize_stats.py @@ -8,11 +8,8 @@ import tyro -def main(results_dir: str, scenes: List[str]): - print("scenes:", scenes) - stage = "compress" - - summary = defaultdict(list) +def main(results_dir: str, scenes: List[str], stage: str = "val"): + stats_all = defaultdict(list) for scene in scenes: scene_dir = os.path.join(results_dir, scene) @@ -25,15 +22,20 @@ def main(results_dir: str, scenes: List[str]): f"stat -c%s {zip_path}", shell=True, capture_output=True ) size = int(out.stdout) - summary["size"].append(size) + stats_all["size"].append(size) with open(os.path.join(scene_dir, f"stats/{stage}_step29999.json"), "r") as f: stats = json.load(f) for k, v in stats.items(): - summary[k].append(v) + stats_all[k].append(v) + + summary = {"scenes": scenes} + for k, v in stats_all.items(): + summary[k] = np.mean(v) + print(summary) - for k, v in summary.items(): - print(k, np.mean(v)) + with open(os.path.join(results_dir, f"{stage}_summary.json"), "w") as f: + json.dump(summary, f, indent=2) if __name__ == "__main__": diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 3e544201b..8711e45b3 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -154,6 +154,13 @@ class Config: # Weight for depth loss depth_lambda: float = 1e-2 + # Enable normal consistency loss. (experimental) + normal_consistency_loss: bool = False + # Weight for normal consistency loss + normal_consistency_lambda: float = 0.05 + # Start applying normal consistency loss after this iteration + normal_consistency_start_iter: int = 7000 + # Dump information to tensorboard every this steps tb_every: int = 100 # Save training images to tensorboard @@ -273,6 +280,12 @@ def __init__( self.world_size = world_size self.device = f"cuda:{local_rank}" + self.render_mode = "RGB" + if cfg.depth_loss: + self.render_mode = "RGB+ED" + if cfg.normal_consistency_loss: + self.render_mode = "RGB+ED+N" + # Where to dump results. os.makedirs(cfg.result_dir, exist_ok=True) @@ -587,13 +600,10 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode=self.render_mode, masks=masks, ) - if renders.shape[-1] == 4: - colors, depths = renders[..., 0:3], renders[..., 3:4] - else: - colors, depths = renders, None + colors = renders[..., :3] if cfg.use_bilateral_grid: grid_y, grid_x = torch.meshgrid( @@ -623,6 +633,7 @@ def train(self): ) loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda if cfg.depth_loss: + depths = renders[..., -1:] # query depths from depth map points = torch.stack( [ @@ -641,6 +652,14 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + normalconsistencyloss = ( + 1 - (normals_rend * normals_surf).sum(dim=-1) + ).mean() + if step > cfg.normal_consistency_start_iter: + loss += normalconsistencyloss * cfg.normal_consistency_lambda if cfg.use_bilateral_grid: tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss @@ -687,6 +706,12 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.normal_consistency_loss: + self.writer.add_scalar( + "train/normalconsistencyloss", + normalconsistencyloss.item(), + step, + ) if cfg.use_bilateral_grid: self.writer.add_scalar("train/tvloss", tvloss.item(), step) if cfg.tb_save_image: @@ -819,7 +844,7 @@ def eval(self, step: int, stage: str = "val"): torch.cuda.synchronize() tic = time.time() - colors, _, _ = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -827,13 +852,23 @@ def eval(self, step: int, stage: str = "val"): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, + render_mode=self.render_mode, masks=masks, ) # [1, H, W, 3] torch.cuda.synchronize() ellipse_time += time.time() - tic - colors = torch.clamp(colors, 0.0, 1.0) + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) canvas_list = [pixels, colors] + if cfg.depth_loss: + depths = renders[..., -1:] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(depths) + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) if world_rank == 0: # write images @@ -927,7 +962,7 @@ def render_traj(self, step: int): camtoworlds = camtoworlds_all[i : i + 1] Ks = K[None] - renders, _, _ = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -935,12 +970,20 @@ def render_traj(self, step: int): sh_degree=cfg.sh_degree, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - render_mode="RGB+ED", + render_mode=self.render_mode, ) # [1, H, W, 4] - colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] - depths = renders[..., 3:4] # [1, H, W, 1] - depths = (depths - depths.min()) / (depths.max() - depths.min()) - canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) + canvas_list = [colors] + if cfg.depth_loss: + depths = renders[..., -1:] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list.append(depths) + if cfg.normal_consistency_loss: + normals_rend = info["normals_rend"] + normals_surf = info["normals_surf"] + canvas_list.extend([normals_rend * 0.5 + 0.5]) + canvas_list.extend([normals_surf * 0.5 + 0.5]) # write images canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 892c6a66f..d5b090add 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -249,7 +249,8 @@ def _world_to_cam( def _fully_fused_projection( means: Tensor, # [N, 3] - covars: Tensor, # [N, 3, 3] + quats: Tensor, + scales: Tensor, viewmats: Tensor, # [C, 4, 4] Ks: Tensor, # [C, 3, 3] width: int, @@ -267,6 +268,10 @@ def _fully_fused_projection( This is a minimal implementation of fully fused version, which has more arguments. Not all arguments are supported. """ + covars, _ = _quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] + normals = _quat_to_rotmat(quats)[..., 2] # [N, 3] + normals = normals.repeat(viewmats.shape[0], 1, 1) # [C, N, 3] + means_c, covars_c = _world_to_cam(means, covars, viewmats) if camera_model == "ortho": @@ -324,7 +329,7 @@ def _fully_fused_projection( radius[~inside] = 0.0 radii = radius.int() - return radii, means2d, depths, conics, compensations + return radii, means2d, depths, normals, conics, compensations @torch.no_grad() diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 1c3826110..a1d97fd36 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -780,7 +780,7 @@ def forward( ) # "covars" and {"quats", "scales"} are mutually exclusive - radii, means2d, depths, conics, compensations = _make_lazy_cuda_func( + radii, means2d, depths, normals, conics, compensations = _make_lazy_cuda_func( "fully_fused_projection_fwd" )( means, @@ -808,10 +808,12 @@ def forward( ctx.eps2d = eps2d ctx.camera_model_type = camera_model_type - return radii, means2d, depths, conics, compensations + return radii, means2d, depths, normals, conics, compensations @staticmethod - def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): + def backward( + ctx, v_radii, v_means2d, v_depths, v_normals, v_conics, v_compensations + ): ( means, covars, @@ -847,6 +849,7 @@ def backward(ctx, v_radii, v_means2d, v_depths, v_conics, v_compensations): compensations, v_means2d.contiguous(), v_depths.contiguous(), + v_normals.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad @@ -1043,6 +1046,7 @@ def forward( radii, means2d, depths, + normals, conics, compensations, ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd")( @@ -1081,7 +1085,16 @@ def forward( ctx.sparse_grad = sparse_grad ctx.camera_model_type = camera_model_type - return camera_ids, gaussian_ids, radii, means2d, depths, conics, compensations + return ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + normals, + conics, + compensations, + ) @staticmethod def backward( @@ -1091,6 +1104,7 @@ def backward( v_radii, v_means2d, v_depths, + v_normals, v_conics, v_compensations, ): @@ -1133,6 +1147,7 @@ def backward( compensations, v_means2d.contiguous(), v_depths.contiguous(), + v_normals.contiguous(), v_conics.contiguous(), v_compensations, ctx.needs_input_grad[4], # viewmats_requires_grad diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index cf0dc8751..897be4f6a 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -96,6 +96,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -139,6 +140,7 @@ fully_fused_projection_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [C, N, 2] const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional const bool viewmats_requires_grad @@ -258,6 +260,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -302,6 +305,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, diff --git a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu index b5757ff40..d107e4360 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_bwd.cu @@ -38,6 +38,7 @@ __global__ void fully_fused_projection_bwd_kernel( // grad outputs const T *__restrict__ v_means2d, // [C, N, 2] const T *__restrict__ v_depths, // [C, N] + const T *__restrict__ v_normals, // [C, N, 3] const T *__restrict__ v_conics, // [C, N, 3] const T *__restrict__ v_compensations, // [C, N] optional // grad inputs @@ -223,9 +224,16 @@ __global__ void fully_fused_projection_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp( - quat, scale, rotmat, v_covar, v_quat, v_scale + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals + v_normals += idx * 3; + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat ); + warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); if (warp_group_g.thread_rank() == 0) { @@ -283,6 +291,7 @@ fully_fused_projection_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [C, N, 2] const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] const torch::Tensor &v_conics, // [C, N, 3] const at::optional &v_compensations, // [C, N] optional const bool viewmats_requires_grad @@ -302,6 +311,7 @@ fully_fused_projection_bwd_tensor( GSPLAT_CHECK_INPUT(conics); GSPLAT_CHECK_INPUT(v_means2d); GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); GSPLAT_CHECK_INPUT(v_conics); if (compensations.has_value()) { GSPLAT_CHECK_INPUT(compensations.value()); @@ -352,6 +362,7 @@ fully_fused_projection_bwd_tensor( : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() diff --git a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu index c651e803d..0acf21127 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_fwd.cu @@ -37,6 +37,7 @@ __global__ void fully_fused_projection_fwd_kernel( int32_t *__restrict__ radii, // [C, N] T *__restrict__ means2d, // [C, N, 2] T *__restrict__ depths, // [C, N] + T *__restrict__ normals, // [C, N, 3] T *__restrict__ conics, // [C, N, 3] T *__restrict__ compensations // [C, N] optional ) { @@ -77,6 +78,7 @@ __global__ void fully_fused_projection_fwd_kernel( // transform Gaussian covariance to camera space mat3 covar; + vec3 normal; if (covars != nullptr) { covars += gid * 6; covar = mat3( @@ -94,9 +96,11 @@ __global__ void fully_fused_projection_fwd_kernel( // compute from quaternions and scales quats += gid * 4; scales += gid * 3; - quat_scale_to_covar_preci( - glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr - ); + quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), + &covar, nullptr); + + mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + normal = rotmat[2]; } mat3 covar_c; covar_world_to_cam(R, covar, covar_c); @@ -185,6 +189,9 @@ __global__ void fully_fused_projection_fwd_kernel( means2d[idx * 2] = mean2d.x; means2d[idx * 2 + 1] = mean2d.y; depths[idx] = mean_c.z; + normals[idx * 3] = normal.x; + normals[idx * 3 + 1] = normal.y; + normals[idx * 3 + 2] = normal.z; conics[idx * 3] = covar2d_inv[0][0]; conics[idx * 3 + 1] = covar2d_inv[0][1]; conics[idx * 3 + 2] = covar2d_inv[1][1]; @@ -198,6 +205,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -235,6 +243,7 @@ fully_fused_projection_fwd_tensor( torch::empty({C, N}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); torch::Tensor depths = torch::empty({C, N}, means.options()); + torch::Tensor normals = torch::empty({C, N, 3}, means.options()); torch::Tensor conics = torch::empty({C, N, 3}, means.options()); torch::Tensor compensations; if (calc_compensations) { @@ -265,11 +274,12 @@ fully_fused_projection_fwd_tensor( radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr ); } - return std::make_tuple(radii, means2d, depths, conics, compensations); + return std::make_tuple(radii, means2d, depths, normals, conics, compensations); } } // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu index e5a0172fe..c43040865 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_bwd.cu @@ -40,6 +40,7 @@ __global__ void fully_fused_projection_packed_bwd_kernel( // grad outputs const T *__restrict__ v_means2d, // [nnz, 2] const T *__restrict__ v_depths, // [nnz] + const T *__restrict__ v_normals, // [nnz, 3] const T *__restrict__ v_conics, // [nnz, 3] const T *__restrict__ v_compensations, // [nnz] optional const bool sparse_grad, // whether the outputs are in COO format [nnz, ...] @@ -215,9 +216,16 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp( - quat, scale, rotmat, v_covar, v_quat, v_scale + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals + v_normals += idx * 3; + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat ); + v_quats += idx * 4; v_scales += idx * 3; v_quats[0] = v_quat[0]; @@ -260,9 +268,16 @@ __global__ void fully_fused_projection_packed_bwd_kernel( mat3 rotmat = quat_to_rotmat(quat); vec4 v_quat(0.f); vec3 v_scale(0.f); - quat_scale_to_covar_vjp( - quat, scale, rotmat, v_covar, v_quat, v_scale + quat_scale_to_covar_vjp(quat, scale, rotmat, v_covar, v_quat, v_scale); + + // add contribution from v_normals + v_normals += idx * 3; + quat_to_rotmat_vjp( + quat, + mat3(0.f, 0.f, 0.f, 0.f, 0.f, 0.f, v_normals[0], v_normals[1], v_normals[2]), + v_quat ); + warpSum(v_quat, warp_group_g); warpSum(v_scale, warp_group_g); if (warp_group_g.thread_rank() == 0) { @@ -323,6 +338,7 @@ fully_fused_projection_packed_bwd_tensor( // grad outputs const torch::Tensor &v_means2d, // [nnz, 2] const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_normals, // [nnz, 3] const torch::Tensor &v_conics, // [nnz, 3] const at::optional &v_compensations, // [nnz] optional const bool viewmats_requires_grad, @@ -344,6 +360,7 @@ fully_fused_projection_packed_bwd_tensor( GSPLAT_CHECK_INPUT(conics); GSPLAT_CHECK_INPUT(v_means2d); GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); GSPLAT_CHECK_INPUT(v_conics); if (compensations.has_value()) { GSPLAT_CHECK_INPUT(compensations.value()); @@ -409,6 +426,7 @@ fully_fused_projection_packed_bwd_tensor( : nullptr, v_means2d.data_ptr(), v_depths.data_ptr(), + v_normals.data_ptr(), v_conics.data_ptr(), v_compensations.has_value() ? v_compensations.value().data_ptr() diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 4d8609f05..c45dacd23 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -43,6 +43,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( int32_t *__restrict__ radii, // [nnz] T *__restrict__ means2d, // [nnz, 2] T *__restrict__ depths, // [nnz] + T *__restrict__ normals, // [nnz, 3] T *__restrict__ conics, // [nnz, 3] T *__restrict__ compensations // [nnz] optional ) { @@ -89,6 +90,7 @@ __global__ void fully_fused_projection_packed_fwd_kernel( mat2 covar2d; vec2 mean2d; mat2 covar2d_inv; + vec3 normal; T compensation; T det; if (valid) { @@ -112,9 +114,10 @@ __global__ void fully_fused_projection_packed_fwd_kernel( // if not then compute it from quaternions and scales quats += col_idx * 4; scales += col_idx * 3; - quat_scale_to_covar_preci( - glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr - ); + quat_scale_to_covar_preci(glm::make_vec4(quats), glm::make_vec3(scales), &covar, nullptr); + + mat3 rotmat = quat_to_rotmat(glm::make_vec4(quats)); + normal = rotmat[2]; } mat3 covar_c; covar_world_to_cam(R, covar, covar_c); @@ -227,6 +230,9 @@ __global__ void fully_fused_projection_packed_fwd_kernel( means2d[thread_data * 2] = mean2d.x; means2d[thread_data * 2 + 1] = mean2d.y; depths[thread_data] = mean_c.z; + normals[thread_data * 3] = normal.x; + normals[thread_data * 3 + 1] = normal.y; + normals[thread_data * 3 + 2] = normal.z; conics[thread_data * 3] = covar2d_inv[0][0]; conics[thread_data * 3 + 1] = covar2d_inv[0][1]; conics[thread_data * 3 + 2] = covar2d_inv[1][1]; @@ -254,6 +260,7 @@ std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, + torch::Tensor, torch::Tensor> fully_fused_projection_packed_fwd_tensor( const torch::Tensor &means, // [N, 3] @@ -327,6 +334,7 @@ fully_fused_projection_packed_fwd_tensor( nullptr, nullptr, nullptr, + nullptr, nullptr ); block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); @@ -343,6 +351,7 @@ fully_fused_projection_packed_fwd_tensor( torch::empty({nnz}, means.options().dtype(torch::kInt32)); torch::Tensor means2d = torch::empty({nnz, 2}, means.options()); torch::Tensor depths = torch::empty({nnz}, means.options()); + torch::Tensor normals = torch::empty({nnz, 3}, means.options()); torch::Tensor conics = torch::empty({nnz, 3}, means.options()); torch::Tensor compensations; if (calc_compensations) { @@ -376,6 +385,7 @@ fully_fused_projection_packed_fwd_tensor( radii.data_ptr(), means2d.data_ptr(), depths.data_ptr(), + normals.data_ptr(), conics.data_ptr(), calc_compensations ? compensations.data_ptr() : nullptr ); @@ -390,6 +400,7 @@ fully_fused_projection_packed_fwd_tensor( radii, means2d, depths, + normals, conics, compensations ); diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 83ec6e77b..afb180279 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -5,6 +5,7 @@ import torch.distributed import torch.nn.functional as F from torch import Tensor +import torch.nn.functional as F from typing_extensions import Literal from .cuda._wrapper import ( @@ -43,7 +44,7 @@ def rasterization( packed: bool = True, tile_size: int = 16, backgrounds: Optional[Tensor] = None, - render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"] = "RGB", sparse_grad: bool = False, absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", @@ -239,7 +240,7 @@ def rasterization( assert opacities.shape == (N,), opacities.shape assert viewmats.shape == (C, 4, 4), viewmats.shape assert Ks.shape == (C, 3, 3), Ks.shape - assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"], render_mode def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tensor: view_list = list( @@ -321,13 +322,14 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso radii, means2d, depths, + normals, conics, compensations, ) = proj_results opacities = opacities[gaussian_ids] # [nnz] else: # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. - radii, means2d, depths, conics, compensations = proj_results + radii, means2d, depths, normals, conics, compensations = proj_results opacities = opacities.repeat(C, 1) # [C, N] camera_ids, gaussian_ids = None, None @@ -488,6 +490,12 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso colors = depths[..., None] if backgrounds is not None: backgrounds = torch.zeros(C, 1, device=backgrounds.device) + elif render_mode in ["RGB+ED+N"]: + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + if backgrounds is not None: + backgrounds = torch.cat( + [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + ) else: # RGB pass @@ -569,16 +577,24 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso packed=packed, absgrad=absgrad, ) - if render_mode in ["ED", "RGB+ED"]: + + if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: # normalize the accumulated depth to get the expected depth - render_colors = torch.cat( - [ - render_colors[..., :-1], - render_colors[..., -1:] / render_alphas.clamp(min=1e-10), - ], - dim=-1, + render_colors[..., -1:] /= render_alphas.clamp(min=1e-10) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + render_colors[..., -1:], + torch.inverse(viewmats), + Ks, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } ) - return render_colors, render_alphas, meta @@ -598,7 +614,7 @@ def _rasterization( sh_degree: Optional[int] = None, tile_size: int = 16, backgrounds: Optional[Tensor] = None, - render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB", + render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"] = "RGB", rasterize_mode: Literal["classic", "antialiased"] = "classic", channel_chunk: int = 32, batch_per_iter: int = 100, @@ -632,7 +648,7 @@ def _rasterization( assert opacities.shape == (N,), opacities.shape assert viewmats.shape == (C, 4, 4), viewmats.shape assert Ks.shape == (C, 3, 3), Ks.shape - assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED", "RGB+ED+N"], render_mode if sh_degree is None: # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] @@ -651,10 +667,10 @@ def _rasterization( # Project Gaussians to 2D. # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. - covars, _ = _quat_scale_to_covar_preci(quats, scales, True, False, triu=False) - radii, means2d, depths, conics, compensations = _fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = _fully_fused_projection( means, - covars, + quats, + scales, viewmats, Ks, width, @@ -722,6 +738,12 @@ def _rasterization( colors = depths[..., None] if backgrounds is not None: backgrounds = torch.zeros(C, 1, device=backgrounds.device) + elif render_mode in ["RGB+ED+N"]: + colors = torch.cat((colors, normals, depths[..., None]), dim=-1) + if backgrounds is not None: + backgrounds = torch.cat( + [backgrounds, torch.zeros(C, 1, device=backgrounds.device)], dim=-1 + ) else: # RGB pass if colors.shape[-1] > channel_chunk: @@ -795,6 +817,23 @@ def _rasterization( "tile_size": tile_size, "n_cameras": C, } + if render_mode in ["ED", "RGB+ED", "RGB+ED+N"]: + # normalize the accumulated depth to get the expected depth + render_colors[..., -1:] /= render_alphas.clamp(min=1e-10) + if render_mode in ["RGB+ED+N"]: + normals_rend = render_colors[..., -4:-1] + normals_surf = depth_to_normal( + render_colors[..., -1:], + torch.inverse(viewmats), + Ks, + ) + normals_surf = normals_surf * (render_alphas).detach() + meta.update( + { + "normals_rend": normals_rend, + "normals_surf": normals_surf, + } + ) return render_colors, render_alphas, meta @@ -927,6 +966,9 @@ def rasterization_inria_wrapper( device = means.device channels = colors.shape[-1] + # rasterization from inria does not do normalization internally + quats = F.normalize(quats, dim=-1) # [N, 4] + render_colors = [] for cid in range(C): FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) diff --git a/tests/test_basic.py b/tests/test_basic.py index 22d2ee227..f086087dc 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -199,7 +199,7 @@ def test_projection( # forward if fused: - radii, means2d, depths, conics, compensations = fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( means, None, quats, @@ -213,7 +213,7 @@ def test_projection( ) else: covars, _ = quat_scale_to_covar_preci(quats, scales, triu=True) # [N, 6] - radii, means2d, depths, conics, compensations = fully_fused_projection( + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( means, covars, None, @@ -225,10 +225,17 @@ def test_projection( calc_compensations=calc_compensations, camera_model=camera_model, ) - _covars, _ = quat_scale_to_covar_preci(quats, scales, triu=False) # [N, 3, 3] - _radii, _means2d, _depths, _conics, _compensations = _fully_fused_projection( + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = _fully_fused_projection( means, - _covars, + quats, + scales, viewmats, Ks, width, @@ -242,6 +249,7 @@ def test_projection( torch.testing.assert_close(radii, _radii, rtol=0, atol=1) torch.testing.assert_close(means2d[valid], _means2d[valid], rtol=1e-4, atol=1e-4) torch.testing.assert_close(depths[valid], _depths[valid], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(normals[valid], _normals[valid], rtol=1e-4, atol=1e-4) torch.testing.assert_close(conics[valid], _conics[valid], rtol=1e-4, atol=1e-4) if calc_compensations: torch.testing.assert_close( @@ -251,12 +259,14 @@ def test_projection( # backward v_means2d = torch.randn_like(means2d) * valid[..., None] v_depths = torch.randn_like(depths) * valid + v_normals = torch.randn_like(normals) * valid[..., None] v_conics = torch.randn_like(conics) * valid[..., None] if calc_compensations: v_compensations = torch.randn_like(compensations) * valid v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d).sum() + (depths * v_depths).sum() + + (normals * v_normals).sum() + (conics * v_conics).sum() + ((compensations * v_compensations).sum() if calc_compensations else 0), (viewmats, quats, scales, means), @@ -264,6 +274,7 @@ def test_projection( _v_viewmats, _v_quats, _v_scales, _v_means = torch.autograd.grad( (_means2d * v_means2d).sum() + (_depths * v_depths).sum() + + (_normals * v_normals).sum() + (_conics * v_conics).sum() + ((_compensations * v_compensations).sum() if calc_compensations else 0), (viewmats, quats, scales, means), @@ -276,7 +287,7 @@ def test_projection( @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -@pytest.mark.parametrize("fused", [False, True]) +@pytest.mark.parametrize("fused", [True]) @pytest.mark.parametrize("sparse_grad", [False, True]) @pytest.mark.parametrize("calc_compensations", [False, True]) @pytest.mark.parametrize("camera_model", ["pinhole", "ortho", "fisheye"]) @@ -312,6 +323,7 @@ def test_fully_fused_projection_packed( radii, means2d, depths, + normals, conics, compensations, ) = fully_fused_projection( @@ -328,7 +340,14 @@ def test_fully_fused_projection_packed( calc_compensations=calc_compensations, camera_model=camera_model, ) - _radii, _means2d, _depths, _conics, _compensations = fully_fused_projection( + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = fully_fused_projection( means, None, quats, @@ -349,6 +368,7 @@ def test_fully_fused_projection_packed( radii, means2d, depths, + normals, conics, compensations, ) = fully_fused_projection( @@ -365,7 +385,14 @@ def test_fully_fused_projection_packed( calc_compensations=calc_compensations, camera_model=camera_model, ) - _radii, _means2d, _depths, _conics, _compensations = fully_fused_projection( + ( + _radii, + _means2d, + _depths, + _normals, + _conics, + _compensations, + ) = fully_fused_projection( means, covars, None, @@ -389,6 +416,9 @@ def test_fully_fused_projection_packed( __depths = torch.sparse_coo_tensor( torch.stack([camera_ids, gaussian_ids]), depths, _depths.shape ).to_dense() + __normals = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), normals, _normals.shape + ).to_dense() __conics = torch.sparse_coo_tensor( torch.stack([camera_ids, gaussian_ids]), conics, _conics.shape ).to_dense() @@ -400,6 +430,7 @@ def test_fully_fused_projection_packed( torch.testing.assert_close(__radii[sel], _radii[sel], rtol=0, atol=1) torch.testing.assert_close(__means2d[sel], _means2d[sel], rtol=1e-4, atol=1e-4) torch.testing.assert_close(__depths[sel], _depths[sel], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(__normals[sel], _normals[sel], rtol=1e-4, atol=1e-4) torch.testing.assert_close(__conics[sel], _conics[sel], rtol=1e-4, atol=1e-4) if calc_compensations: torch.testing.assert_close( @@ -409,10 +440,12 @@ def test_fully_fused_projection_packed( # backward v_means2d = torch.randn_like(_means2d) * sel[..., None] v_depths = torch.randn_like(_depths) * sel + v_normals = torch.randn_like(_normals) * sel[..., None] v_conics = torch.randn_like(_conics) * sel[..., None] _v_viewmats, _v_quats, _v_scales, _v_means = torch.autograd.grad( (_means2d * v_means2d).sum() + (_depths * v_depths).sum() + + (_normals * v_normals).sum() + (_conics * v_conics).sum(), (viewmats, quats, scales, means), retain_graph=True, @@ -420,6 +453,7 @@ def test_fully_fused_projection_packed( v_viewmats, v_quats, v_scales, v_means = torch.autograd.grad( (means2d * v_means2d[__radii > 0]).sum() + (depths * v_depths[__radii > 0]).sum() + + (normals * v_normals[__radii > 0]).sum() + (conics * v_conics[__radii > 0]).sum(), (viewmats, quats, scales, means), retain_graph=True, @@ -476,7 +510,6 @@ def test_rasterize_to_pixels(test_data, channels: int): fully_fused_projection, isect_offset_encode, isect_tiles, - quat_scale_to_covar_preci, rasterize_to_pixels, ) @@ -494,11 +527,9 @@ def test_rasterize_to_pixels(test_data, channels: int): colors = torch.randn(C, len(means), channels, device=device) backgrounds = torch.rand((C, colors.shape[-1]), device=device) - covars, _ = quat_scale_to_covar_preci(quats, scales, compute_preci=False, triu=True) - # Project Gaussians to 2D - radii, means2d, depths, conics, compensations = fully_fused_projection( - means, covars, None, None, viewmats, Ks, width, height + radii, means2d, depths, normals, conics, compensations = fully_fused_projection( + means, None, quats, scales, viewmats, Ks, width, height ) opacities = opacities.repeat(C, 1) diff --git a/tests/test_rasterization.py b/tests/test_rasterization.py index 1aad8738b..9eb6170da 100644 --- a/tests/test_rasterization.py +++ b/tests/test_rasterization.py @@ -17,7 +17,7 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") @pytest.mark.parametrize("per_view_color", [True, False]) @pytest.mark.parametrize("sh_degree", [None, 3]) -@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D"]) +@pytest.mark.parametrize("render_mode", ["RGB", "RGB+D", "D", "RGB+ED+N"]) @pytest.mark.parametrize("packed", [True, False]) def test_rasterization( per_view_color: bool, sh_degree: Optional[int], render_mode: str, packed: bool