diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py index 2cf22fe02..908b4eb73 100644 --- a/examples/simple_trainer.py +++ b/examples/simple_trainer.py @@ -6,13 +6,13 @@ from typing import Dict, List, Optional, Tuple import imageio +import nerfview import numpy as np import torch import torch.nn.functional as F import tqdm import tyro import viser -import nerfview from datasets.colmap import Dataset, Parser from datasets.traj import generate_interpolated_path from torch import Tensor @@ -141,6 +141,11 @@ class Config: # Weight for depth loss depth_lambda: float = 1e-2 + # Distoration loss. (experimental) + dist_loss: bool = False + # Weight for distortion loss + dist_lambda: float = 1e-3 + # Dump information to tensorboard every this steps tb_every: int = 100 # Save training images to tensorboard @@ -472,7 +477,8 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode="RGB+ED" if (cfg.depth_loss or cfg.dist_loss) else "RGB", + distloss=self.cfg.dist_loss, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -510,12 +516,17 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda + if cfg.dist_loss: + distloss = info["render_distloss"].mean() + loss += distloss * cfg.dist_lambda loss.backward() desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " if cfg.depth_loss: desc += f"depth loss={depthloss.item():.6f}| " + if cfg.dist_loss: + desc += f"dist loss={distloss.item():.6f}| " if cfg.pose_opt and cfg.pose_noise: # monitor the pose error if we inject noise pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) @@ -533,6 +544,8 @@ def train(self): self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.dist_loss: + self.writer.add_scalar("train/distloss", distloss.item(), step) if cfg.tb_save_image: canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() canvas = canvas.reshape(-1, *canvas.shape[2:]) diff --git a/examples/simple_trainer_mcmc.py b/examples/simple_trainer_mcmc.py index 8c6a21924..86af8945b 100644 --- a/examples/simple_trainer_mcmc.py +++ b/examples/simple_trainer_mcmc.py @@ -6,29 +6,25 @@ from typing import Dict, List, Optional, Tuple import imageio +import nerfview import numpy as np import torch import torch.nn.functional as F import tqdm import tyro import viser -import nerfview from datasets.colmap import Dataset, Parser from datasets.traj import generate_interpolated_path +from simple_trainer import create_splats_with_optimizers from torch import Tensor from torch.utils.tensorboard import SummaryWriter from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -from utils import ( - AppearanceOptModule, - CameraOptModule, - set_random_seed, -) +from utils import AppearanceOptModule, CameraOptModule, set_random_seed + from gsplat import quat_scale_to_covar_preci -from gsplat.rendering import rasterization from gsplat.relocation import compute_relocation -from gsplat.cuda_legacy._torch_impl import scale_rot_to_cov3d -from simple_trainer import create_splats_with_optimizers +from gsplat.rendering import rasterization @dataclass @@ -139,6 +135,11 @@ class Config: # Weight for depth loss depth_lambda: float = 1e-2 + # Distoration loss. (experimental) + dist_loss: bool = False + # Weight for distortion loss + dist_lambda: float = 1e-3 + # Dump information to tensorboard every this steps tb_every: int = 100 # Save training images to tensorboard @@ -394,7 +395,8 @@ def train(self): near_plane=cfg.near_plane, far_plane=cfg.far_plane, image_ids=image_ids, - render_mode="RGB+ED" if cfg.depth_loss else "RGB", + render_mode="RGB+ED" if (cfg.depth_loss or cfg.dist_loss) else "RGB", + distloss=self.cfg.dist_loss, ) if renders.shape[-1] == 4: colors, depths = renders[..., 0:3], renders[..., 3:4] @@ -432,6 +434,9 @@ def train(self): disp_gt = 1.0 / depths_gt # [1, M] depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale loss += depthloss * cfg.depth_lambda + if cfg.dist_loss: + distloss = info["render_distloss"].mean() + loss += distloss * cfg.dist_lambda loss = ( loss @@ -448,6 +453,8 @@ def train(self): desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " if cfg.depth_loss: desc += f"depth loss={depthloss.item():.6f}| " + if cfg.dist_loss: + desc += f"dist loss={distloss.item():.6f}| " if cfg.pose_opt and cfg.pose_noise: # monitor the pose error if we inject noise pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index e73cdd86a..e37fdfd71 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -309,7 +309,8 @@ def accumulate( camera_ids: Tensor, # [M] image_width: int, image_height: int, -) -> Tuple[Tensor, Tensor]: + distloss: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: """Alpah compositing of 2D Gaussians in Pure Pytorch. This function performs alpha compositing for Gaussians based on the pair of indices @@ -334,23 +335,30 @@ def accumulate( means2d: Gaussian means in 2D. [C, N, 2] conics: Inverse of the 2D Gaussian covariance, Only upper triangle values. [C, N, 3] opacities: Per-view Gaussian opacities (for example, when antialiasing is - enabled, Gaussian in each view would efficiently have different opacity). [C, N] + enabled, Gaussian in each view would efficiently have different opacity). [C, N] colors: Per-view Gaussian colors. Supports N-D features. [C, N, channels] gaussian_ids: Collection of Gaussian indices to be rasterized. A flattened list of shape [M]. pixel_ids: Collection of pixel indices (row-major) to be rasterized. A flattened list of shape [M]. camera_ids: Collection of camera indices to be rasterized. A flattened list of shape [M]. image_width: Image width. image_height: Image height. + distloss: If True, render the per-pixel distortion loss map. This requires the depth + information in the last channel of `colors`. Default is False. Returns: A tuple: - **renders**: Accumulated colors. [C, image_height, image_width, channels] - **alphas**: Accumulated opacities. [C, image_height, image_width, 1] + - **dists**: Distortion loss if `distloss=True` else None. [C, image_height, image_width, 1] """ try: - from nerfacc import accumulate_along_rays, render_weight_from_alpha + from nerfacc import ( + accumulate_along_rays, + exclusive_sum, + render_weight_from_alpha, + ) except ImportError: raise ImportError("Please install nerfacc package: pip install nerfacc") @@ -386,7 +394,18 @@ def accumulate( weights, None, ray_indices=indices, n_rays=total_pixels ).reshape(C, image_height, image_width, 1) - return renders, alphas + if distloss: + depths = colors[camera_ids, gaussian_ids, -1] + loss_bi_0 = weights * depths * exclusive_sum(weights, indices=indices) + loss_bi_1 = weights * exclusive_sum(weights * depths, indices=indices) + dists = 2 * (loss_bi_0 - loss_bi_1) + dists = accumulate_along_rays(dists, None, indices, total_pixels).reshape( + C, image_height, image_width, 1 + ) + else: + dists = None + + return renders, alphas, dists def _rasterize_to_pixels( @@ -400,6 +419,7 @@ def _rasterize_to_pixels( isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] backgrounds: Optional[Tensor] = None, # [C, channels] + distloss: bool = False, batch_per_iter: int = 100, ): """Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels()`. @@ -434,6 +454,10 @@ def _rasterize_to_pixels( (C, image_height, image_width, colors.shape[-1]), device=device ) render_alphas = torch.zeros((C, image_height, image_width, 1), device=device) + if distloss: + render_distloss = torch.zeros((C, image_height, image_width, 1), device=device) + else: + render_distloss = None # Split Gaussians into batches and iteratively accumulate the renderings block_size = tile_size * tile_size @@ -464,7 +488,7 @@ def _rasterize_to_pixels( break # Accumulate the renderings within this batch of Gaussians. - renders_step, accs_step = accumulate( + renders_step, accs_step, dists_step = accumulate( means2d, conics, opacities, @@ -474,17 +498,21 @@ def _rasterize_to_pixels( camera_ids, image_width, image_height, + distloss=distloss, ) render_colors = render_colors + renders_step * transmittances[..., None] render_alphas = render_alphas + accs_step * transmittances[..., None] + if distloss: + render_distloss = render_distloss + dists_step * ( + transmittances[..., None] ** 2 + ) - render_alphas = render_alphas if backgrounds is not None: render_colors = render_colors + backgrounds[:, None, None, :] * ( 1.0 - render_alphas ) - return render_colors, render_alphas + return render_colors, render_alphas, render_distloss def _eval_sh_bases_fast(basis_dim: int, dirs: Tensor): diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index 9043181ce..e0e1a3caa 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -381,6 +381,7 @@ def rasterize_to_pixels( backgrounds: Optional[Tensor] = None, # [C, channels] packed: bool = False, absgrad: bool = False, + distloss: bool = False, ) -> Tuple[Tensor, Tensor]: """Rasterizes Gaussians to pixels. @@ -397,12 +398,15 @@ def rasterize_to_pixels( backgrounds: Background colors. [C, channels]. Default: None. packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False. absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False. + distloss: If True, the per-pixel distortion loss will be rendered. We expect + the depths are concatenated into the last channel of the colors. Default: False. Returns: A tuple: - **Rendered colors**. [C, image_height, image_width, channels] - **Rendered alphas**. [C, image_height, image_width, 1] + - **Rendered distloss**. [C, image_height, image_width, 1] if `distloss` is True, otherwise None. """ C = isect_offsets.size(0) @@ -478,7 +482,7 @@ def rasterize_to_pixels( tile_width * tile_size >= image_width ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}" - render_colors, render_alphas = _RasterizeToPixels.apply( + render_colors, render_alphas, render_distloss = _RasterizeToPixels.apply( means2d.contiguous(), conics.contiguous(), colors.contiguous(), @@ -490,11 +494,12 @@ def rasterize_to_pixels( isect_offsets.contiguous(), flatten_ids.contiguous(), absgrad, + distloss, ) if padded_channels > 0: render_colors = render_colors[..., :-padded_channels] - return render_colors, render_alphas + return render_colors, render_alphas, render_distloss @torch.no_grad() @@ -820,8 +825,9 @@ def forward( isect_offsets: Tensor, # [C, tile_height, tile_width] flatten_ids: Tensor, # [n_isects] absgrad: bool, - ) -> Tuple[Tensor, Tensor]: - render_colors, render_alphas, last_ids = _make_lazy_cuda_func( + distloss: bool, + ) -> Tuple[Tensor, Tensor, Tensor]: + render_colors, render_alphas, render_distloss, last_ids = _make_lazy_cuda_func( "rasterize_to_pixels_fwd" )( means2d, @@ -834,6 +840,7 @@ def forward( tile_size, isect_offsets, flatten_ids, + distloss, ) ctx.save_for_backward( @@ -844,6 +851,7 @@ def forward( backgrounds, isect_offsets, flatten_ids, + render_colors, render_alphas, last_ids, ) @@ -851,16 +859,18 @@ def forward( ctx.height = height ctx.tile_size = tile_size ctx.absgrad = absgrad + ctx.distloss = distloss # double to float render_alphas = render_alphas.float() - return render_colors, render_alphas + return render_colors, render_alphas, render_distloss @staticmethod def backward( ctx, v_render_colors: Tensor, # [C, H, W, 3] v_render_alphas: Tensor, # [C, H, W, 1] + v_render_distloss: Optional[Tensor], # [C, H, W, 1] ): ( means2d, @@ -870,6 +880,7 @@ def backward( backgrounds, isect_offsets, flatten_ids, + render_colors, render_alphas, last_ids, ) = ctx.saved_tensors @@ -877,6 +888,12 @@ def backward( height = ctx.height tile_size = ctx.tile_size absgrad = ctx.absgrad + distloss = ctx.distloss + if distloss: + assert v_render_distloss is not None, "v_render_distloss should not be None" + v_render_distloss = v_render_distloss.contiguous() + else: + assert v_render_distloss is None, "v_render_distloss should be None" ( v_means2d_abs, @@ -895,10 +912,12 @@ def backward( tile_size, isect_offsets, flatten_ids, + render_colors, render_alphas, last_ids, v_render_colors.contiguous(), v_render_alphas.contiguous(), + v_render_distloss, absgrad, ) @@ -924,6 +943,7 @@ def backward( None, None, None, + None, ) diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 7af30f56a..595670b8a 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -114,27 +114,29 @@ torch::Tensor isect_offset_encode_tensor(const torch::Tensor &isect_ids, // [n_i const uint32_t C, const uint32_t tile_width, const uint32_t tile_height); -std::tuple rasterize_to_pixels_fwd_tensor( +std::tuple +rasterize_to_pixels_fwd_tensor( // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &colors, // [C, N, D] - const torch::Tensor &opacities, // [N] - const at::optional &backgrounds, // [C, D] + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const at::optional &backgrounds, // [C, channels] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] - const torch::Tensor &flatten_ids // [n_isects] -); + const torch::Tensor &flatten_ids, // [n_isects] + // options + const bool distloss); std::tuple rasterize_to_pixels_bwd_tensor( // Gaussian parameters - const torch::Tensor &means2d, // [C, N, 2] - const torch::Tensor &conics, // [C, N, 3] - const torch::Tensor &colors, // [C, N, 3] - const torch::Tensor &opacities, // [N] + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] const at::optional &backgrounds, // [C, 3] // image size const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, @@ -142,11 +144,14 @@ rasterize_to_pixels_bwd_tensor( const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] // forward outputs + const torch::Tensor &render_colors, // [C, image_height, image_width, COLOR_DIM] const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] const torch::Tensor &last_ids, // [C, image_height, image_width] // gradients of outputs const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const at::optional + &v_render_distloss, // [C, image_height, image_width, 1] // options bool absgrad); @@ -216,10 +221,6 @@ fully_fused_projection_packed_bwd_tensor( const bool viewmats_requires_grad, const bool sparse_grad); std::tuple -compute_relocation_tensor( - torch::Tensor& opacities, - torch::Tensor& scales, - torch::Tensor& ratios, - torch::Tensor& binoms, - const int n_max -); +compute_relocation_tensor(torch::Tensor &opacities, torch::Tensor &scales, + torch::Tensor &ratios, torch::Tensor &binoms, + const int n_max); diff --git a/gsplat/cuda/csrc/rasterization.cu b/gsplat/cuda/csrc/rasterization.cu index ddb8856a3..6db92641f 100644 --- a/gsplat/cuda/csrc/rasterization.cu +++ b/gsplat/cuda/csrc/rasterization.cu @@ -529,7 +529,9 @@ __global__ void rasterize_to_pixels_fwd_kernel( const int32_t *__restrict__ flatten_ids, // [n_isects] float *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] float *__restrict__ render_alphas, // [C, image_height, image_width, 1] - int32_t *__restrict__ last_ids // [C, image_height, image_width] + // if render distort maps, the last channel of render_colors should be depth + float *__restrict__ render_distloss, // [C, image_height, image_width, 1] optional + int32_t *__restrict__ last_ids // [C, image_height, image_width] ) { // each thread draws one pixel, but also timeshares caching gaussians in a // shared tile @@ -547,6 +549,9 @@ __global__ void rasterize_to_pixels_fwd_kernel( if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } + if (render_distloss != nullptr) { + render_distloss += camera_id * image_height * image_width; + } float px = (float)j + 0.5f; float py = (float)i + 0.5f; @@ -586,6 +591,12 @@ __global__ void rasterize_to_pixels_fwd_kernel( // designated pixel uint32_t tr = block.thread_rank(); + // Per-pixel distortion error proposed in Mip-NeRF 360. + // Implemented reference: + // https://github.com/nerfstudio-project/nerfacc/blob/master/nerfacc/losses.py#L7 + float distort = 0.f; + float accum_vis_depth = 0.f; // accumulated vis * depth + float pix_out[COLOR_DIM] = {0.f}; for (uint32_t b = 0; b < num_batches; ++b) { // resync all threads before beginning next batch @@ -638,6 +649,17 @@ __global__ void rasterize_to_pixels_fwd_kernel( for (uint32_t k = 0; k < COLOR_DIM; ++k) { pix_out[k] += c_ptr[k] * vis; } + if (render_distloss != nullptr) { + // the last channel of colors is depth + const float depth = c_ptr[COLOR_DIM - 1]; + // in nerfacc, loss_bi_0 = weights * t_mids * exclusive_sum(weights) + const float distort_bi_0 = vis * depth * (1.0f - T); + // in nerfacc, loss_bi_1 = weights * exclusive_sum(weights * t_mids) + const float distort_bi_1 = vis * accum_vis_depth; + distort += 2.0f * (distort_bi_0 - distort_bi_1); + accum_vis_depth += vis * depth; + } + cur_idx = batch_start + t; T = next_T; @@ -658,10 +680,14 @@ __global__ void rasterize_to_pixels_fwd_kernel( } // index in bin of last gaussian in this pixel last_ids[pix_id] = static_cast(cur_idx); + if (render_distloss != nullptr) { + render_distloss[pix_id] = distort; + } } } -std::tuple rasterize_to_pixels_fwd_tensor( +std::tuple +rasterize_to_pixels_fwd_tensor( // Gaussian parameters const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] const torch::Tensor &conics, // [C, N, 3] or [nnz, 3] @@ -672,8 +698,9 @@ std::tuple rasterize_to_pixels_fwd_ const uint32_t image_width, const uint32_t image_height, const uint32_t tile_size, // intersections const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] - const torch::Tensor &flatten_ids // [n_isects] -) { + const torch::Tensor &flatten_ids, // [n_isects] + // options + const bool distloss) { DEVICE_GUARD(means2d); CHECK_INPUT(means2d); CHECK_INPUT(conics); @@ -702,6 +729,11 @@ std::tuple rasterize_to_pixels_fwd_ means2d.options().dtype(torch::kFloat32)); torch::Tensor alphas = torch::empty({C, image_height, image_width, 1}, means2d.options().dtype(torch::kFloat32)); + torch::Tensor distortions; + if (distloss) { + distortions = torch::empty({C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32)); + } torch::Tensor last_ids = torch::empty({C, image_height, image_width}, means2d.options().dtype(torch::kInt32)); @@ -728,6 +760,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 2: @@ -745,6 +778,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 3: @@ -762,6 +796,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 4: @@ -779,6 +814,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 5: @@ -796,6 +832,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 8: @@ -813,6 +850,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 9: @@ -830,6 +868,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 16: @@ -847,6 +886,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 17: @@ -864,6 +904,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 32: @@ -881,6 +922,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 33: @@ -898,6 +940,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 64: @@ -915,6 +958,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 65: @@ -932,6 +976,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 128: @@ -949,6 +994,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 129: @@ -966,6 +1012,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 256: @@ -983,6 +1030,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 257: @@ -1000,6 +1048,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 512: @@ -1017,6 +1066,7 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; case 513: @@ -1034,12 +1084,13 @@ std::tuple rasterize_to_pixels_fwd_ image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), renders.data_ptr(), alphas.data_ptr(), + distloss ? distortions.data_ptr() : nullptr, last_ids.data_ptr()); break; default: AT_ERROR("Unsupported number of channels: ", channels); } - return std::make_tuple(renders, alphas, last_ids); + return std::make_tuple(renders, alphas, distortions, last_ids); } template @@ -1056,12 +1107,16 @@ __global__ void rasterize_to_pixels_bwd_kernel( const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] const int32_t *__restrict__ flatten_ids, // [n_isects] // fwd outputs + const float + *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] const float *__restrict__ render_alphas, // [C, image_height, image_width, 1] const int32_t *__restrict__ last_ids, // [C, image_height, image_width] // grad outputs const float *__restrict__ v_render_colors, // [C, image_height, image_width, COLOR_DIM] const float *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] + const float + *__restrict__ v_render_distloss, // [C, image_height, image_width, 1] optional // grad inputs float2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] float2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] @@ -1076,6 +1131,7 @@ __global__ void rasterize_to_pixels_bwd_kernel( uint32_t j = block.group_index().z * tile_size + block.thread_index().x; tile_offsets += camera_id * tile_height * tile_width; + render_colors += camera_id * image_height * image_width * COLOR_DIM; render_alphas += camera_id * image_height * image_width; last_ids += camera_id * image_height * image_width; v_render_colors += camera_id * image_height * image_width * COLOR_DIM; @@ -1083,6 +1139,9 @@ __global__ void rasterize_to_pixels_bwd_kernel( if (backgrounds != nullptr) { backgrounds += camera_id * COLOR_DIM; } + if (v_render_distloss != nullptr) { + v_render_distloss += camera_id * image_height * image_width; + } const float px = (float)j + 0.5f; const float py = (float)i + 0.5f; @@ -1126,6 +1185,20 @@ __global__ void rasterize_to_pixels_bwd_kernel( } const float v_render_a = v_render_alphas[pix_id]; + // prepare for distortion + float v_distort = 0.f; + float accum_d, accum_w; + float accum_d_buffer, accum_w_buffer, distort_buffer; + if (v_render_distloss != nullptr) { + v_distort = v_render_distloss[pix_id]; + // last channel of render_colors is accumulated depth + accum_d_buffer = render_colors[pix_id * COLOR_DIM + COLOR_DIM - 1]; + accum_d = accum_d_buffer; + accum_w_buffer = render_alphas[pix_id]; + accum_w = accum_w_buffer; + distort_buffer = 0.f; + } + // collect and process batches of gaussians // each thread loads one gaussian at a time before rasterizing const uint32_t tr = block.thread_rank(); @@ -1223,6 +1296,23 @@ __global__ void rasterize_to_pixels_bwd_kernel( v_alpha += -T_final * ra * accum; } + // contribution from distortion + if (v_render_distloss != nullptr) { + // last channel of colors is depth + float depth = rgbs_batch[t * COLOR_DIM + COLOR_DIM - 1]; + float dl_dw = + 2.0f * (2.0f * (depth * accum_w_buffer - accum_d_buffer) + + (accum_d - depth * accum_w)); + // df / d(alpha) + v_alpha += (dl_dw * T - distort_buffer * ra) * v_distort; + accum_d_buffer -= fac * depth; + accum_w_buffer -= fac; + distort_buffer += dl_dw * fac; + // df / d(depth). put it in the last channel of v_rgb + v_rgb_local[COLOR_DIM - 1] += + 2.0f * fac * (2.0f - 2.0f * T - accum_w + fac) * v_distort; + } + if (opac * vis <= 0.999f) { const float v_sigma = -opac * vis * v_alpha; v_conic_local = {0.5f * v_sigma * delta.x * delta.x, @@ -1291,11 +1381,14 @@ rasterize_to_pixels_bwd_tensor( const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] const torch::Tensor &flatten_ids, // [n_isects] // forward outputs + const torch::Tensor &render_colors, // [C, image_height, image_width, COLOR_DIM] const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] const torch::Tensor &last_ids, // [C, image_height, image_width] // gradients of outputs const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const at::optional + &v_render_distloss, // [C, image_height, image_width, 1] // options bool absgrad) { DEVICE_GUARD(means2d); @@ -1312,6 +1405,9 @@ rasterize_to_pixels_bwd_tensor( if (backgrounds.has_value()) { CHECK_INPUT(backgrounds.value()); } + if (v_render_distloss.has_value()) { + CHECK_INPUT(v_render_distloss.value()); + } bool packed = means2d.dim() == 2; @@ -1357,8 +1453,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1379,8 +1479,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1401,8 +1505,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1423,8 +1531,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1445,8 +1557,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1467,8 +1583,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1489,8 +1609,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1511,8 +1635,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1533,8 +1661,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1555,8 +1687,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1577,8 +1713,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1599,8 +1739,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1621,8 +1765,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1644,9 +1792,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1668,9 +1819,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1692,9 +1846,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1716,9 +1873,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1740,9 +1900,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), @@ -1764,9 +1927,12 @@ rasterize_to_pixels_bwd_tensor( : nullptr, image_width, image_height, tile_size, tile_width, tile_height, tile_offsets.data_ptr(), flatten_ids.data_ptr(), - render_alphas.data_ptr(), last_ids.data_ptr(), - v_render_colors.data_ptr(), + render_colors.data_ptr(), render_alphas.data_ptr(), + last_ids.data_ptr(), v_render_colors.data_ptr(), v_render_alphas.data_ptr(), + v_render_distloss.has_value() + ? v_render_distloss.value().data_ptr() + : nullptr, absgrad ? (float2 *)v_means2d_abs.data_ptr() : nullptr, (float2 *)v_means2d.data_ptr(), (float3 *)v_conics.data_ptr(), v_colors.data_ptr(), diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 7339e5b2f..cf56b2a19 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -36,6 +36,7 @@ def rasterization( sparse_grad: bool = False, absgrad: bool = False, rasterize_mode: Literal["classic", "antialiased"] = "classic", + distloss: bool = False, channel_chunk: int = 32, ) -> Tuple[Tensor, Tensor, Dict]: """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). @@ -106,6 +107,13 @@ def rasterization( `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_, which is shown to be more effective for splitting Gaussians during training. + .. note:: + **Distortion Map**: If `distloss` is True, the function will render out the per-pixel + distortion loss, following the formulation in the paper `Mip-NeRF 360: Unbounded + Anti-Aliased Neural Radiance Fields `_. This requires + depth rendering, so the `render_mode` should be "D", "ED", "RGB+D", or "RGB+ED". The + distortion map will be stored in the meta as `meta["render_distloss"]`. + .. warning:: This function is currently not differentiable w.r.t. the camera intrinsics `Ks`. @@ -145,6 +153,9 @@ def rasterization( `meta["means2d"].absgrad`. Default is False. rasterize_mode: The rasterization mode. Supported modes are "classic" and "antialiased". Default is "classic". + distloss: If true, the function will render out the distortions map as well + and store it in the meta. This requires depth rendering, so the render_mode + should be "D", "ED", "RGB+D", or "RGB+ED". Default is False. channel_chunk: The number of channels to render in one go. Default is 32. If the required rendering channels are larger than this value, the rendering will be done looply in chunks. @@ -198,6 +209,13 @@ def rasterization( assert viewmats.shape == (C, 4, 4), viewmats.shape assert Ks.shape == (C, 3, 3), Ks.shape assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode + if distloss: + assert render_mode in [ + "D", + "ED", + "RGB+D", + "RGB+ED", + ], f"distloss requires depth rendering, render_mode should be D, ED, RGB+D, or RGB+ED, but got {render_mode}" if sh_degree is None: # treat colors as post-activation values, should be in shape [N, D] or [C, N, D] @@ -330,7 +348,7 @@ def rasterization( if colors.shape[-1] > channel_chunk: # slice into chunks n_chunks = (colors.shape[-1] + channel_chunk - 1) // channel_chunk - render_colors, render_alphas = [], [] + render_colors = [] for i in range(n_chunks): colors_chunk = colors[..., i * channel_chunk : (i + 1) * channel_chunk] backgrounds_chunk = ( @@ -338,7 +356,10 @@ def rasterization( if backgrounds is not None else None ) - render_colors_, render_alphas_ = rasterize_to_pixels( + # If distloss=True, we expect the last channel of colors is depths, + # from which we compute the distortions. So the only last chunk here produces + # a valid render_distloss output. + render_colors_, render_alphas, render_distloss = rasterize_to_pixels( means2d, conics, colors_chunk, @@ -351,13 +372,12 @@ def rasterization( backgrounds=backgrounds_chunk, packed=packed, absgrad=absgrad, + distloss=distloss, ) render_colors.append(render_colors_) - render_alphas.append(render_alphas_) render_colors = torch.cat(render_colors, dim=-1) - render_alphas = render_alphas[0] # discard the rest else: - render_colors, render_alphas = rasterize_to_pixels( + render_colors, render_alphas, render_distloss = rasterize_to_pixels( means2d, conics, colors, @@ -370,6 +390,7 @@ def rasterization( backgrounds=backgrounds, packed=packed, absgrad=absgrad, + distloss=distloss, ) if render_mode in ["ED", "RGB+ED"]: # normalize the accumulated depth to get the expected depth @@ -398,6 +419,7 @@ def rasterization( "width": width, "height": height, "tile_size": tile_size, + "render_distloss": render_distloss, } return render_colors, render_alphas, meta diff --git a/tests/test_basic.py b/tests/test_basic.py index 8c546b450..081f63316 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -439,7 +439,8 @@ def test_isect(test_data): @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") @pytest.mark.parametrize("channels", [3, 32, 128]) -def test_rasterize_to_pixels(test_data, channels: int): +@pytest.mark.parametrize("distloss", [False, True]) +def test_rasterize_to_pixels(test_data, channels: int, distloss: bool): from gsplat.cuda._torch_impl import _rasterize_to_pixels from gsplat.cuda._wrapper import ( fully_fused_projection, @@ -481,6 +482,12 @@ def test_rasterize_to_pixels(test_data, channels: int): ) isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) + if distloss: + colors = torch.cat([colors, depths[..., None]], dim=-1) + backgrounds = torch.cat( + [backgrounds, torch.zeros_like(backgrounds[..., :1])], dim=-1 + ) + means2d.requires_grad = True conics.requires_grad = True colors.requires_grad = True @@ -488,7 +495,7 @@ def test_rasterize_to_pixels(test_data, channels: int): backgrounds.requires_grad = True # forward - render_colors, render_alphas = rasterize_to_pixels( + render_colors, render_alphas, render_distloss = rasterize_to_pixels( means2d, conics, colors, @@ -499,8 +506,9 @@ def test_rasterize_to_pixels(test_data, channels: int): isect_offsets, flatten_ids, backgrounds=backgrounds, + distloss=distloss, ) - _render_colors, _render_alphas = _rasterize_to_pixels( + _render_colors, _render_alphas, _render_distloss = _rasterize_to_pixels( means2d, conics, colors, @@ -511,18 +519,26 @@ def test_rasterize_to_pixels(test_data, channels: int): isect_offsets, flatten_ids, backgrounds=backgrounds, + distloss=distloss, ) torch.testing.assert_close(render_colors, _render_colors) torch.testing.assert_close(render_alphas, _render_alphas) + torch.testing.assert_close(render_distloss, _render_distloss) # backward v_render_colors = torch.randn_like(render_colors) v_render_alphas = torch.randn_like(render_alphas) + if distloss: + v_render_distloss = torch.randn_like(render_distloss) v_means2d, v_conics, v_colors, v_opacities, v_backgrounds = torch.autograd.grad( - (render_colors * v_render_colors).sum() - + (render_alphas * v_render_alphas).sum(), + ( + (render_colors * v_render_colors).sum() + + (render_alphas * v_render_alphas).sum() + + ((render_distloss * v_render_distloss).sum() if distloss else 0) + ), (means2d, conics, colors, opacities, backgrounds), + retain_graph=True, ) ( _v_means2d, @@ -531,9 +547,13 @@ def test_rasterize_to_pixels(test_data, channels: int): _v_opacities, _v_backgrounds, ) = torch.autograd.grad( - (_render_colors * v_render_colors).sum() - + (_render_alphas * v_render_alphas).sum(), + ( + (_render_colors * v_render_colors).sum() + + (_render_alphas * v_render_alphas).sum() + + ((render_distloss * v_render_distloss).sum() if distloss else 0) + ), (means2d, conics, colors, opacities, backgrounds), + retain_graph=True, ) torch.testing.assert_close(v_means2d, _v_means2d, rtol=5e-3, atol=5e-3) torch.testing.assert_close(v_conics, _v_conics, rtol=1e-3, atol=1e-3)