From 0a200cdb8293a035b804b87e4ae6987cb3b4222b Mon Sep 17 00:00:00 2001 From: Koki Hokao <86301305+khokao@users.noreply.github.com> Date: Mon, 30 Sep 2024 18:40:07 +0900 Subject: [PATCH] Fix 2DGS for Correct Normal and Median Depth Calculation (#430) * Fix: Add missing camera offsets for render outputs in 2DGS kernel * Fix: Correct render_normals transformation * black formatting --- gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu | 4 ++++ gsplat/rendering.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu index 839a72f8..c36d7605 100644 --- a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu @@ -72,6 +72,10 @@ __global__ void rasterize_to_pixels_fwd_2dgs_kernel( render_colors += camera_id * image_height * image_width * COLOR_DIM; // get the global offset of the pixel w.r.t the camera render_alphas += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera last_ids += camera_id * image_height * image_width; // get the global offset of the pixel w.r.t the camera + render_normals += camera_id * image_height * image_width * 3; + render_distort += camera_id * image_height * image_width; + render_median += camera_id * image_height * image_width; + median_ids += camera_id * image_height * image_width; // get the global offset of the background and mask if (backgrounds != nullptr) { diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 83ec6e77..78da64ab 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -1306,7 +1306,9 @@ def rasterization_2dgs( "gradient_2dgs": densify, # This holds the gradient used for densification for 2dgs } - render_normals = render_normals @ torch.linalg.inv(viewmats)[0, :3, :3].T + render_normals = torch.einsum( + "...ij,...hwj->...hwi", torch.linalg.inv(viewmats)[..., :3, :3], render_normals + ) return ( render_colors,