diff --git a/docs/source/apis/rasterization.rst b/docs/source/apis/rasterization.rst
index f5f6d7eb3..0d0831953 100644
--- a/docs/source/apis/rasterization.rst
+++ b/docs/source/apis/rasterization.rst
@@ -1,6 +1,9 @@
 Rasterization
 ===================================
 
+3DGS
+------
+
 .. currentmodule:: gsplat
 
 Given a set of 3D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, covariances
@@ -38,4 +41,36 @@ projection equation:
 Where :math:`[W | t]` is the world-to-camera transformation matrix, and :math:`f_{x}, f_{y}`
 are the focal lengths of the camera.
 
-.. autofunction:: rasterization
\ No newline at end of file
+.. autofunction:: rasterization
+
+2DGS
+------
+
+Given a set of 2D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, two principal tangent vectors
+embedded as the first two columns of a rotation matrix :math:`R \in \mathbb{R}^{3\times3}`, and a scale matrix :math:`S \in \mathbb{R}^{3\times3}`
+representing the scaling along the two principal tangential directions, we first transform pixels into the splat's local tangent frame
+by :math:`(WH)^{-1} \in \mathbb{R}^{4\times4}` and compute weights via ray-splat intersection. We then follow a sorting and rendering procedure similar to 3DGS.
+
+Note that :math:`H` is the transformation from the splat's local tangent plane :math:`\{u, v\}` into world space
+
+.. math::
+
+    H = \begin{bmatrix}
+        RS & \mu \\
+        0 & 1
+    \end{bmatrix}
+
+and :math:`W \in \mathbb{R}^{4\times4}` is the transformation matrix from world space to image space.
+
+
+Splatting is done via ray-splat plane intersection. Each pixel is considered as an x-plane :math:`h_{x}=(-1, 0, 0, x)^{T}`
+and a y-plane :math:`h_{y}=(0, -1, 0, y)^{T}`, and the intersection between a splat and the pixel :math:`p=(x, y)` is defined
+as the intersection between the x-plane, the y-plane, and the splat's tangent plane. We first transform :math:`h_{x}` to :math:`h_{u}` and :math:`h_{y}`
+to :math:`h_{v}` in the splat's tangent frame via the inverse transformation :math:`(WH)^{-1}`. As the intersection point must fall on both :math:`h_{u}` and :math:`h_{v}`, we have an efficient
+solution:
+
+.. math::
+    u(p) = \frac{h^{2}_{u}h^{4}_{v}-h^{4}_{u}h^{2}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}},
+    v(p) = \frac{h^{4}_{u}h^{1}_{v}-h^{1}_{u}h^{4}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}}
+
+.. autofunction:: rasterization_2dgs
\ No newline at end of file
diff --git a/docs/source/apis/utils.rst b/docs/source/apis/utils.rst
index cd83393be..cc839dabe 100644
--- a/docs/source/apis/utils.rst
+++ b/docs/source/apis/utils.rst
@@ -4,6 +4,9 @@ Utils
 
 Below are the basic functions that supports the rasterization.
 
+3DGS
+-----
+
 .. currentmodule:: gsplat
 
 .. autofunction:: spherical_harmonics
@@ -27,3 +30,17 @@ Below are the basic functions that supports the rasterization.
 .. autofunction:: accumulate
 
 .. autofunction:: rasterization_inria_wrapper
+
+2DGS
+-----
+.. currentmodule:: gsplat
+
+.. autofunction:: fully_fused_projection_2dgs
+
+.. autofunction:: rasterize_to_pixels_2dgs
+
+.. autofunction:: rasterize_to_indices_in_range_2dgs
+
+.. autofunction:: accumulate_2dgs
+
+.. autofunction:: rasterization_2dgs_inria_wrapper
\ No newline at end of file
diff --git a/docs/source/tests/eval.rst b/docs/source/tests/eval.rst
index 9f2a2fdc0..49c13525e 100644
--- a/docs/source/tests/eval.rst
+++ b/docs/source/tests/eval.rst
@@ -1,6 +1,9 @@
 Evaluation
 ===================================
 
+3DGS
+----------------------------------------------
+
 .. 
table:: Performance on `Mip-NeRF 360 Captures `_ (Averaged Over 7 Scenes) +---------------------+-------+-------+-------+------------------+------------+ @@ -140,3 +143,100 @@ The evaluation of `inria-X` can be reproduced with our forked wersion of the official implementation at `here `_, with the command :code:`python full_eval_m360.py` (commit 36546ce). + +2DGS +---------------------------------------------- + +No Regularization +---------------------------------------------- + +.. table:: Performance on `Mip-NeRF 360 Captures `_ (Averaged Over 7 Scenes) + ++---------------------+-------+-------+-------+------------------+------------+ +| | PSNR | SSIM | LPIPS | Train Mem | Train Time | ++=====================+=======+=======+=======+==================+============+ +| inria-30k | 28.73 | 0.860 | 0.148 | 3.73 GB | 22m16s | ++---------------------+-------+-------+-------+------------------+------------+ +| gsplat-30k | 28.76 | 0.867 | 0.145 | **3.70 GB** | **15m44s** | ++---------------------+-------+-------+-------+------------------+------------+ + +With Normal Consistency and Distortion Regularization +------------------------------------------------------ + ++---------------------+-------+-------+-------+------------------+------------+ +| | PSNR | SSIM | LPIPS | Train Mem | Train Time | ++=====================+=======+=======+=======+==================+============+ +| inria-30k | 28.05 | 0.848 | 0.186 | 3.76 GB | 22m06s | ++---------------------+-------+-------+-------+------------------+------------+ +| gsplat-30k | 27.80 | 0.842 | 0.169 | **3.61 GB** | **16m44s** | ++---------------------+-------+-------+-------+------------------+------------+ + +Runtime and GPU Memory +---------------------------------------------- + ++-----------------+---------+--------+---------+--------+---------+--------+--------+ +| Train Mem (GB) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump | ++=================+=========+========+=========+========+=========+========+========+ +| inria-30k |**6.74** | 2.27 | 2.06 | 4.79 | 2.25 | 2.40 |**5.58**| ++-----------------+---------+--------+---------+--------+---------+--------+--------+ +| gsplat-30k | 6.89 |**2.19**| **1.93**|**4.48**| **2.14**|**2.30**| 6.00 | ++-----------------+---------+--------+---------+--------+---------+--------+--------+ + ++-----------------+---------+--------+---------+--------+---------+--------+--------+ +| Train Time (s) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump | ++=================+=========+========+=========+========+=========+========+========+ +| inria-30k | 1463 | 1237 | 1318 | 1298 | 1422 | 1314 | 1252 | ++-----------------+---------+--------+---------+--------+---------+--------+--------+ +| gsplat-30k |**1231** |**788** | **803**| **985**| **828**| **789**|**1057**| ++-----------------+---------+--------+---------+--------+---------+--------+--------+ + + +Reproduced Metrics +---------------------------------------------- + ++------------+---------+--------+---------+--------+---------+-------+-------+ +| PSNR | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump | ++============+=========+========+=========+========+=========+=======+=======+ +| inria-30k | 24.92 | 31.87 | 28.78 | 26.88 | 31.08 | 31.21 | 26.36 | ++------------+---------+--------+---------+--------+---------+-------+-------+ +| gsplat-30k | 24.97 | 31.94 | 28.76 | 26.95 | 31.08 | 31.27 | 26.37 | ++------------+---------+--------+---------+--------+---------+-------+-------+ + 
++------------+---------+--------+---------+--------+---------+-------+-------+
+| SSIM       | Bicycle | Bonsai | Counter | Garden | Kitchen | Room  | Stump |
++============+=========+========+=========+========+=========+=======+=======+
+| inria-30k  | 0.741   | 0.937  | 0.899   | 0.847  | 0.921   | 0.914 | 0.760 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 0.764   | 0.937  | 0.899   | 0.849  | 0.921   | 0.915 | 0.761 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+
++------------+---------+--------+---------+--------+---------+-------+-------+
+| LPIPS      | Bicycle | Bonsai | Counter | Garden | Kitchen | Room  | Stump |
++============+=========+========+=========+========+=========+=======+=======+
+| inria-30k  | 0.199   | 0.136  | 0.164   | 0.093  | 0.101   | 0.172 | 0.168 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 0.189   | 0.134  | 0.162   | 0.091  | 0.101   | 0.169 | 0.166 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+| Number of GSs   | Bicycle | Bonsai | Counter | Garden | Kitchen | Room  | Stump |
++=================+=========+========+=========+========+=========+=======+=======+
+| inria-30k       | 3.97M   | 0.91M  | 0.72M   | 2.79M  | 0.85M   | 1.01M | 3.27M |
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k      | 3.88M   | 0.92M  | 0.73M   | 2.49M  | 0.87M   | 1.03M | 3.40M |
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+
+Note: Evaluations for 2DGS are conducted on an NVIDIA RTX 4090 GPU. The LPIPS metric is evaluated
+using :code:`from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity`, which
+is different from what's reported in the original paper, which uses
+:code:`from lpipsPyTorch import lpips`.
+
+The evaluation of `gsplat-X` can be reproduced with the command
+:code:`cd examples; bash benchmarks/basic_2dgs.sh`
+within the gsplat repo (commit 48abf70).
+
+The evaluation of `inria-X` can be
+reproduced with our forked version of the official implementation at
+`here `_;
+you need to change :code:`--model_type 2dgs` to :code:`--model_type 2dgs-inria` in
+:code:`benchmarks/basic_2dgs.sh` and run :code:`cd examples; bash benchmarks/basic_2dgs.sh` (commit 28c928a).
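+
+For reference, a minimal sketch of how the LPIPS numbers above are computed with
+``torchmetrics`` (illustrative; it assumes ``pred`` and ``gt`` are ``[1, 3, H, W]``
+tensors with values in ``[0, 1]`` on the same device):
+
+.. code-block:: python
+
+    from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+
+    lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to("cuda")
+    score = lpips(pred, gt)  # lower is better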
\ No newline at end of file diff --git a/examples/benchmarks/basic_2dgs.sh b/examples/benchmarks/basic_2dgs.sh new file mode 100755 index 000000000..04d3d8fd4 --- /dev/null +++ b/examples/benchmarks/basic_2dgs.sh @@ -0,0 +1,52 @@ +SCENE_DIR="data/360_v2" +RESULT_DIR="results/benchmark_2dgs" +SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers + +for SCENE in $SCENE_LIST; +do + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + echo "Running $SCENE" + + # train without eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 2dgs \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + + # run eval and render + for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + do + CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor $DATA_FACTOR \ + --model_type 2dgs \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --ckpt $CKPT + done +done + + +for SCENE in $SCENE_LIST; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val*.json; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; + do + echo $STATS + cat $STATS; + echo + done +done \ No newline at end of file diff --git a/examples/image_fitting.py b/examples/image_fitting.py index 672f7a1aa..434b7869b 100644 --- a/examples/image_fitting.py +++ b/examples/image_fitting.py @@ -2,7 +2,7 @@ import os import time from pathlib import Path -from typing import Optional +from typing import Literal, Optional import numpy as np import torch @@ -10,7 +10,7 @@ from PIL import Image from torch import Tensor, optim -from gsplat import rasterization +from gsplat import rasterization, rasterization_2dgs class SimpleTrainer: @@ -79,6 +79,7 @@ def train( iterations: int = 1000, lr: float = 0.01, save_imgs: bool = False, + model_type: Literal["3dgs", "2dgs"] = "3dgs", ): optimizer = optim.Adam( [self.rgbs, self.means, self.scales, self.opacities, self.quats], lr @@ -94,9 +95,16 @@ def train( ], device=self.device, ) + + if model_type == "3dgs": + rasterize_fnc = rasterization + elif model_type == "2dgs": + rasterize_fnc = rasterization_2dgs + for iter in range(iterations): start = time.time() - renders, _, _ = rasterization( + + renders = rasterize_fnc( self.means, self.quats / self.quats.norm(dim=-1, keepdim=True), self.scales, @@ -107,7 +115,7 @@ def train( self.W, self.H, packed=False, - ) + )[0] out_img = renders[0] torch.cuda.synchronize() times[0] += time.time() - start @@ -125,7 +133,7 @@ def train( if save_imgs: # save them as a gif with PIL frames = [Image.fromarray(frame) for frame in frames] - out_dir = os.path.join(os.getcwd(), "renders") + out_dir = os.path.join(os.getcwd(), "results") os.makedirs(out_dir, exist_ok=True) frames[0].save( f"{out_dir}/training.gif", @@ -158,6 +166,7 @@ def main( img_path: Optional[Path] = None, iterations: int = 1000, lr: float = 0.01, + model_type: Literal["3dgs", "2dgs"] = "3dgs", ) -> None: if img_path: gt_image = image_path_to_tensor(img_path) @@ -172,6 +181,7 @@ def main( iterations=iterations, lr=lr, save_imgs=save_imgs, + model_type=model_type, ) diff --git a/examples/requirements.txt b/examples/requirements.txt index 6ebe9abeb..148b61135 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -17,4 +17,5 @@ 
 tyro>=0.8.8
 Pillow
 tensorboard
 pyyaml
+matplotlib
 git+https://github.com/rahul-goel/fused-ssim@84422e0da94c516220eb3acedb907e68809e9e01
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 53c751575..a1bf7c052 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -15,7 +15,7 @@
 import viser
 import yaml
 from datasets.colmap import Dataset, Parser
-from datasets.traj import generate_interpolated_path, generate_ellipse_path_z
+from datasets.traj import generate_ellipse_path_z, generate_interpolated_path
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/simple_trainer_2dgs.py b/examples/simple_trainer_2dgs.py
new file mode 100644
index 000000000..c37191f66
--- /dev/null
+++ b/examples/simple_trainer_2dgs.py
@@ -0,0 +1,970 @@
+import json
+import math
+import os
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional, Tuple
+
+import imageio
+import nerfview
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import viser
+from datasets.colmap import Dataset, Parser
+from datasets.traj import generate_interpolated_path
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from utils import (
+    AppearanceOptModule,
+    CameraOptModule,
+    apply_depth_colormap,
+    colormap,
+    knn,
+    rgb_to_sh,
+    set_random_seed,
+)
+
+from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper
+from gsplat.strategy import DefaultStrategy
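+
+# Illustrative usage (mirrors benchmarks/basic_2dgs.sh; flags are generated by
+# tyro from the Config dataclass below):
+#
+#   python simple_trainer_2dgs.py --model_type 2dgs \
+#       --data_dir data/360_v2/garden/ --result_dir results/benchmark_2dgs/garden/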
+
+
+@dataclass
+class Config:
+    # Disable viewer
+    disable_viewer: bool = False
+    # Path to the .pt file. If provided, it will skip training and render a video
+    ckpt: Optional[str] = None
+
+    # Path to the Mip-NeRF 360 dataset
+    data_dir: str = "data/360_v2/garden"
+    # Downsample factor for the dataset
+    data_factor: int = 4
+    # Directory to save results
+    result_dir: str = "results/garden"
+    # Every N images there is a test image
+    test_every: int = 8
+    # Random crop size for training (experimental)
+    patch_size: Optional[int] = None
+    # A global scaler that applies to the scene size related parameters
+    global_scale: float = 1.0
+
+    # Port for the viewer server
+    port: int = 8080
+
+    # Batch size for training. Learning rates are scaled automatically
+    batch_size: int = 1
+    # A global factor to scale the number of training steps
+    steps_scaler: float = 1.0
+
+    # Number of training steps
+    max_steps: int = 30_000
+    # Steps to evaluate the model
+    eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+    # Steps to save the model
+    save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+
+    # Initialization strategy
+    init_type: str = "sfm"
+    # Initial number of GSs. Ignored if using sfm
+    init_num_pts: int = 100_000
+    # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
+    init_extent: float = 3.0
+    # Degree of spherical harmonics
+    sh_degree: int = 3
+    # Turn on another SH degree every this many steps
+    sh_degree_interval: int = 1000
+    # Initial opacity of GS
+    init_opa: float = 0.1
+    # Initial scale of GS
+    init_scale: float = 1.0
+    # Weight for SSIM loss
+    ssim_lambda: float = 0.2
+
+    # Near plane clipping distance
+    near_plane: float = 0.2
+    # Far plane clipping distance
+    far_plane: float = 200
+
+    # GSs with opacity below this value will be pruned
+    prune_opa: float = 0.05
+    # GSs with image plane gradient above this value will be split/duplicated
+    grow_grad2d: float = 0.0002
+    # GSs with scale below this value will be duplicated. Above will be split
+    grow_scale3d: float = 0.01
+    # GSs with scale above this value will be pruned.
+    prune_scale3d: float = 0.1
+
+    # Start refining GSs after this iteration
+    refine_start_iter: int = 500
+    # Stop refining GSs after this iteration
+    refine_stop_iter: int = 15_000
+    # Reset opacities every this many steps
+    reset_every: int = 3000
+    # Refine GSs every this many steps
+    refine_every: int = 100
+
+    # Use packed mode for rasterization; this leads to less memory usage but is slightly slower.
+    packed: bool = False
+    # Use sparse gradients for optimization. (experimental)
+    sparse_grad: bool = False
+    # Use absolute gradient for pruning. This typically requires a larger --grow_grad2d, e.g., 0.0008 or 0.0006
+    absgrad: bool = False
+    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+    antialiased: bool = False
+    # Whether to use the revised opacity heuristic from arXiv:2404.06109 (experimental)
+    revised_opacity: bool = False
+
+    # Use random background for training to discourage transparency
+    random_bkgd: bool = False
+
+    # Enable camera optimization.
+    pose_opt: bool = False
+    # Learning rate for camera optimization
+    pose_opt_lr: float = 1e-5
+    # Regularization for camera optimization as weight decay
+    pose_opt_reg: float = 1e-6
+    # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+    pose_noise: float = 0.0
+
+    # Enable appearance optimization. (experimental)
+    app_opt: bool = False
+    # Appearance embedding dimension
+    app_embed_dim: int = 16
+    # Learning rate for appearance optimization
+    app_opt_lr: float = 1e-3
+    # Regularization for appearance optimization as weight decay
+    app_opt_reg: float = 1e-6
+
+    # Enable depth loss. (experimental)
+    depth_loss: bool = False
+    # Weight for depth loss
+    depth_lambda: float = 1e-2
+
+    # Enable normal consistency loss. (Currently for 2DGS only)
+    normal_loss: bool = False
+    # Weight for normal loss
+    normal_lambda: float = 5e-2
+    # Iteration to start normal consistency regularization
+    normal_start_iter: int = 7_000
+
+    # Distortion loss. (experimental)
+    dist_loss: bool = False
+    # Weight for distortion loss
+    dist_lambda: float = 1e-2
+    # Iteration to start distortion loss regularization
+    dist_start_iter: int = 3_000
+
+    # Model for splatting.
+ model_type: Literal["2dgs", "2dgs-inria"] = "2dgs" + + # Dump information to tensorboard every this steps + tb_every: int = 100 + # Save training images to tensorboard + tb_save_image: bool = False + + def adjust_steps(self, factor: float): + self.eval_steps = [int(i * factor) for i in self.eval_steps] + self.save_steps = [int(i * factor) for i in self.save_steps] + self.max_steps = int(self.max_steps * factor) + self.sh_degree_interval = int(self.sh_degree_interval * factor) + self.refine_start_iter = int(self.refine_start_iter * factor) + self.refine_stop_iter = int(self.refine_stop_iter * factor) + self.reset_every = int(self.reset_every * factor) + self.refine_every = int(self.refine_every * factor) + + +def create_splats_with_optimizers( + parser: Parser, + init_type: str = "sfm", + init_num_pts: int = 100_000, + init_extent: float = 3.0, + init_opacity: float = 0.1, + init_scale: float = 1.0, + scene_scale: float = 1.0, + sh_degree: int = 3, + sparse_grad: bool = False, + batch_size: int = 1, + feature_dim: Optional[int] = None, + device: str = "cuda", +) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: + if init_type == "sfm": + points = torch.from_numpy(parser.points).float() + rgbs = torch.from_numpy(parser.points_rgb / 255.0).float() + elif init_type == "random": + points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1) + rgbs = torch.rand((init_num_pts, 3)) + else: + raise ValueError("Please specify a correct init_type: sfm or random") + + N = points.shape[0] + # Initialize the GS size to be the average dist of the 3 nearest neighbors + dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,] + dist_avg = torch.sqrt(dist2_avg) + scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3] + quats = torch.rand((N, 4)) # [N, 4] + opacities = torch.logit(torch.full((N,), init_opacity)) # [N,] + + params = [ + # name, value, lr + ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale), + ("scales", torch.nn.Parameter(scales), 5e-3), + ("quats", torch.nn.Parameter(quats), 1e-3), + ("opacities", torch.nn.Parameter(opacities), 5e-2), + ] + + if feature_dim is None: + # color is SH coefficients. 
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3] + colors[:, 0, :] = rgb_to_sh(rgbs) + params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3)) + params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20)) + else: + # features will be used for appearance and view-dependent shading + features = torch.rand(N, feature_dim) # [N, feature_dim] + params.append(("features", torch.nn.Parameter(features), 2.5e-3)) + colors = torch.logit(rgbs) # [N, 3] + params.append(("colors", torch.nn.Parameter(colors), 2.5e-3)) + + splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) + # Scale learning rate based on batch size, reference: + # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ + # Note that this would not make the training exactly equivalent, see + # https://arxiv.org/pdf/2402.18824v1 + optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [{"params": splats[name], "lr": lr * math.sqrt(batch_size)}], + eps=1e-15 / math.sqrt(batch_size), + betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__(self, cfg: Config) -> None: + set_random_seed(42) + + self.cfg = cfg + self.device = "cuda" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=True, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + init_type=cfg.init_type, + init_num_pts=cfg.init_num_pts, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sh_degree=cfg.sh_degree, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + ) + print("Model initialized. 
Number of GS:", len(self.splats["means"])) + self.model_type = cfg.model_type + + if self.model_type == "2dgs": + key_for_gradient = "gradient_2dgs" + else: + key_for_gradient = "means2d" + + # Densification Strategy + self.strategy = DefaultStrategy( + verbose=True, + prune_opa=cfg.prune_opa, + grow_grad2d=cfg.grow_grad2d, + grow_scale3d=cfg.grow_scale3d, + prune_scale3d=cfg.prune_scale3d, + # refine_scale2d_stop_iter=4000, # splatfacto behavior + refine_start_iter=cfg.refine_start_iter, + refine_stop_iter=cfg.refine_stop_iter, + reset_every=cfg.reset_every, + refine_every=cfg.refine_every, + absgrad=cfg.absgrad, + revised_opacity=cfg.revised_opacity, + key_for_gradient=key_for_gradient, + ) + self.strategy.check_sanity(self.splats, self.optimizers) + self.strategy_state = self.strategy.initialize_state() + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + + self.app_optimizers = [] + if cfg.app_opt: + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. + torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + + # Losses & Metrics. 
+        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
+        self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+        self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(
+            self.device
+        )
+
+        # Viewer
+        if not self.cfg.disable_viewer:
+            self.server = viser.ViserServer(port=cfg.port, verbose=False)
+            self.viewer = nerfview.Viewer(
+                server=self.server,
+                render_fn=self._viewer_render_fn,
+                mode="training",
+            )
+
+    def rasterize_splats(
+        self,
+        camtoworlds: Tensor,
+        Ks: Tensor,
+        width: int,
+        height: int,
+        **kwargs,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Dict]:
+        means = self.splats["means"]  # [N, 3]
+        # quats = F.normalize(self.splats["quats"], dim=-1)  # [N, 4]
+        # rasterization does normalization internally
+        quats = self.splats["quats"]  # [N, 4]
+        scales = torch.exp(self.splats["scales"])  # [N, 3]
+        opacities = torch.sigmoid(self.splats["opacities"])  # [N,]
+
+        image_ids = kwargs.pop("image_ids", None)
+        if self.cfg.app_opt:
+            colors = self.app_module(
+                features=self.splats["features"],
+                embed_ids=image_ids,
+                dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+                sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
+            )
+            colors = colors + self.splats["colors"]
+            colors = torch.sigmoid(colors)
+        else:
+            colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1)  # [N, K, 3]
+
+        assert self.cfg.antialiased is False, "Antialiased is not supported for 2DGS"
+
+        if self.model_type == "2dgs":
+            (
+                render_colors,
+                render_alphas,
+                render_normals,
+                normals_from_depth,
+                render_distort,
+                render_median,
+                info,
+            ) = rasterization_2dgs(
+                means=means,
+                quats=quats,
+                scales=scales,
+                opacities=opacities,
+                colors=colors,
+                viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
+                Ks=Ks,  # [C, 3, 3]
+                width=width,
+                height=height,
+                packed=self.cfg.packed,
+                absgrad=self.cfg.absgrad,
+                sparse_grad=self.cfg.sparse_grad,
+                **kwargs,
+            )
+        elif self.model_type == "2dgs-inria":
+            # The inria wrapper returns the rendered colors and alphas together
+            # as one tuple; normals and distortion are carried in `info`.
+            renders, info = rasterization_2dgs_inria_wrapper(
+                means=means,
+                quats=quats,
+                scales=scales,
+                opacities=opacities,
+                colors=colors,
+                viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
+                Ks=Ks,  # [C, 3, 3]
+                width=width,
+                height=height,
+                packed=self.cfg.packed,
+                absgrad=self.cfg.absgrad,
+                sparse_grad=self.cfg.sparse_grad,
+                **kwargs,
+            )
+            render_colors, render_alphas = renders
+            render_normals = info["normals_rend"]
+            normals_from_depth = info["normals_surf"]
+            render_distort = info["render_distloss"]
+            render_median = render_colors[..., 3]
+
+        return (
+            render_colors,
+            render_alphas,
+            render_normals,
+            normals_from_depth,
+            render_distort,
+            render_median,
+            info,
+        )
+
+    def train(self):
+        cfg = self.cfg
+        device = self.device
+
+        # Dump cfg.
+        with open(f"{cfg.result_dir}/cfg.json", "w") as f:
+            json.dump(vars(cfg), f)
+
+        max_steps = cfg.max_steps
+        init_step = 0
+
+        schedulers = [
+            # means has a learning rate schedule that ends at 0.01 of the initial value
+            torch.optim.lr_scheduler.ExponentialLR(
+                self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
+            ),
+        ]
+        if cfg.pose_opt:
+            # pose optimization has a learning rate schedule
+            schedulers.append(
+                torch.optim.lr_scheduler.ExponentialLR(
+                    self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
+                )
+            )
+
+        trainloader = torch.utils.data.DataLoader(
+            self.trainset,
+            batch_size=cfg.batch_size,
+            shuffle=True,
+            num_workers=4,
+            persistent_workers=True,
+            pin_memory=True,
+        )
+        trainloader_iter = iter(trainloader)
+
+        # Training loop.
+ global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # sh schedule + sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree) + + # forward + ( + renders, + alphas, + normals, + normals_from_depth, + render_distort, + render_median, + info, + ) = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=sh_degree_to_use, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB+D", + distloss=self.cfg.dist_loss, + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - self.ssim( + pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2) + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda + if cfg.depth_loss: + # query depths from depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + + if cfg.normal_loss: + if step > cfg.normal_start_iter: + curr_normal_lambda = cfg.normal_lambda + else: + curr_normal_lambda = 0.0 + # normal consistency loss + normals = normals.squeeze(0).permute((2, 0, 1)) + normals_from_depth *= alphas.squeeze(0).detach() + if len(normals_from_depth.shape) == 4: + normals_from_depth = normals_from_depth.squeeze(0) + normals_from_depth = normals_from_depth.permute((2, 0, 1)) + normal_error = (1 - (normals * normals_from_depth).sum(dim=0))[None] + normalloss = curr_normal_lambda * normal_error.mean() + loss += normalloss + + if cfg.dist_loss: + if step > cfg.dist_start_iter: + curr_dist_lambda = cfg.dist_lambda + else: + curr_dist_lambda = 0.0 + distloss = render_distort.mean() + loss += distloss * curr_dist_lambda + + loss.backward() + + desc = 
f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.dist_loss: + desc += f"dist loss={distloss.item():.6f}" + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.normal_loss: + self.writer.add_scalar("train/normalloss", normalloss.item(), step) + if cfg.dist_loss: + self.writer.add_scalar("train/distloss", distloss.item(), step) + if cfg.tb_save_image: + canvas = ( + torch.cat([pixels, colors[..., :3]], dim=2) + .detach() + .cpu() + .numpy() + ) + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + self.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." + gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers.values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # save checkpoint + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["means"]), + } + print("Step: ", step, stats) + with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f: + json.dump(stats, f) + torch.save( + { + "step": step, + "splats": self.splats.state_dict(), + }, + f"{self.ckpt_dir}/ckpt_{step}.pt", + ) + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1: + self.eval(step) + self.render_traj(step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. 
+ self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = {"psnr": [], "ssim": [], "lpips": []} + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + ( + colors, + alphas, + normals, + normals_from_depth, + render_distort, + render_median, + _, + ) = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 3] + colors = torch.clamp(colors, 0.0, 1.0) + colors = colors[..., :3] # Take RGB channels + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + # write images + canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy() + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8) + ) + + # write median depths + render_median = (render_median - render_median.min()) / ( + render_median.max() - render_median.min() + ) + # render_median = render_median.detach().cpu().squeeze(0).unsqueeze(-1).repeat(1, 1, 3).numpy() + render_median = ( + render_median.detach().cpu().squeeze(0).repeat(1, 1, 3).numpy() + ) + + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}_median_depth_{step}.png", + (render_median * 255).astype(np.uint8), + ) + + # write normals + normals = (normals * 0.5 + 0.5).squeeze(0).cpu().numpy() + normals_output = (normals * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}_normal_{step}.png", normals_output + ) + + # write normals from depth + normals_from_depth *= alphas.squeeze(0).detach() + normals_from_depth = (normals_from_depth * 0.5 + 0.5).cpu().numpy() + normals_from_depth = (normals_from_depth - np.min(normals_from_depth)) / ( + np.max(normals_from_depth) - np.min(normals_from_depth) + ) + normals_from_depth_output = (normals_from_depth * 255).astype(np.uint8) + if len(normals_from_depth_output.shape) == 4: + normals_from_depth_output = normals_from_depth_output.squeeze(0) + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}_normals_from_depth_{step}.png", + normals_from_depth_output, + ) + + # write distortions + + render_dist = render_distort + dist_max = torch.max(render_dist) + dist_min = torch.min(render_dist) + render_dist = (render_dist - dist_min) / (dist_max - dist_min) + render_dist = ( + colormap(render_dist.cpu().numpy()[0]) + .permute((1, 2, 0)) + .numpy() + .astype(np.uint8) + ) + imageio.imwrite( + f"{self.render_dir}/val_{i:04d}_distortions_{step}.png", render_dist + ) + + pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors, pixels)) + metrics["ssim"].append(self.ssim(colors, pixels)) + metrics["lpips"].append(self.lpips(colors, pixels)) + + ellipse_time /= len(valloader) + + psnr = torch.stack(metrics["psnr"]).mean() + ssim = torch.stack(metrics["ssim"]).mean() + lpips = torch.stack(metrics["lpips"]).mean() + print( + f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} " + f"Time: {ellipse_time:.3f}s/image " + f"Number of GS: 
{len(self.splats['means'])}" + ) + # save stats as json + stats = { + "psnr": psnr.item(), + "ssim": ssim.item(), + "lpips": lpips.item(), + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["means"]), + } + with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"val/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds = self.parser.camtoworlds[5:-5] + camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4] + camtoworlds = np.concatenate( + [ + camtoworlds, + np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds = torch.from_numpy(camtoworlds).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"): + renders, _, _, surf_normals, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds[i : i + 1], + Ks=K[None], + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3] + depths = renders[0, ..., 3:4] # [H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + + surf_normals = (surf_normals - surf_normals.min()) / ( + surf_normals.max() - surf_normals.min() + ) + + # write images + canvas = torch.cat( + [colors, depths.repeat(1, 1, 3)], dim=1 if width > height else 1 + ) + canvas = (canvas.cpu().numpy() * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _, _, _, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + width=W, + height=H, + sh_degree=self.cfg.sh_degree, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(cfg: Config): + runner = Runner(cfg) + + if cfg.ckpt is not None: + # run eval only + ckpt = torch.load(cfg.ckpt, map_location=runner.device) + for k in runner.splats.keys(): + runner.splats[k].data = ckpt["splats"][k] + runner.eval(step=ckpt["step"]) + runner.render_traj(step=ckpt["step"]) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... 
Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + main(cfg) diff --git a/examples/utils.py b/examples/utils.py index 750322e5f..b79f4244b 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -5,6 +5,8 @@ from sklearn.neighbors import NearestNeighbors from torch import Tensor import torch.nn.functional as F +import matplotlib.pyplot as plt +from matplotlib import colormaps class CameraOptModule(torch.nn.Module): @@ -152,3 +154,71 @@ def set_random_seed(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) + + +# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163 +def colormap(img, cmap="jet"): + W, H = img.shape[:2] + dpi = 300 + fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi) + im = ax.imshow(img, cmap=cmap) + ax.set_axis_off() + fig.colorbar(im, ax=ax) + fig.tight_layout() + fig.canvas.draw() + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + img = torch.from_numpy(data).float().permute(2, 0, 1) + plt.close() + return img + + +def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor: + """Convert single channel to a color img. + + Args: + img (torch.Tensor): (..., 1) float32 single channel image. + colormap (str): Colormap for img. + + Returns: + (..., 3) colored img with colors in [0, 1]. + """ + img = torch.nan_to_num(img, 0) + if colormap == "gray": + return img.repeat(1, 1, 3) + img_long = (img * 255).long() + img_long_min = torch.min(img_long) + img_long_max = torch.max(img_long) + assert img_long_min >= 0, f"the min value is {img_long_min}" + assert img_long_max <= 255, f"the max value is {img_long_max}" + return torch.tensor( + colormaps[colormap].colors, # type: ignore + device=img.device, + )[img_long[..., 0]] + + +def apply_depth_colormap( + depth: torch.Tensor, + acc: torch.Tensor = None, + near_plane: float = None, + far_plane: float = None, +) -> torch.Tensor: + """Converts a depth image to color for easier analysis. + + Args: + depth (torch.Tensor): (..., 1) float32 depth. + acc (torch.Tensor | None): (..., 1) optional accumulation mask. + near_plane: Closest depth to consider. If None, use min image value. + far_plane: Furthest depth to consider. If None, use max image value. + + Returns: + (..., 3) colored depth image with colors in [0, 1]. 
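+
+    Example (illustrative)::
+
+        depth_rgb = apply_depth_colormap(depth, acc)  # [..., 3] values in [0, 1]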
+ """ + near_plane = near_plane or float(torch.min(depth)) + far_plane = far_plane or float(torch.max(depth)) + depth = (depth - near_plane) / (far_plane - near_plane + 1e-10) + depth = torch.clip(depth, 0.0, 1.0) + img = apply_float_colormap(depth, colormap="turbo") + if acc is not None: + img = img * acc + (1.0 - acc) + return img diff --git a/gsplat/__init__.py b/gsplat/__init__.py index 8982d8a3c..df47d1555 100644 --- a/gsplat/__init__.py +++ b/gsplat/__init__.py @@ -2,6 +2,7 @@ from .compression import PngCompression from .cuda._torch_impl import accumulate +from .cuda._torch_impl_2dgs import accumulate_2dgs from .cuda._wrapper import ( fully_fused_projection, isect_offset_encode, @@ -12,21 +13,26 @@ rasterize_to_pixels, spherical_harmonics, world_to_cam, + fully_fused_projection_2dgs, + rasterize_to_pixels_2dgs, + rasterize_to_indices_in_range_2dgs, ) from .rendering import ( rasterization, + rasterization_2dgs, rasterization_inria_wrapper, + rasterization_2dgs_inria_wrapper, ) from .strategy import DefaultStrategy, MCMCStrategy, Strategy from .version import __version__ - all = [ "PngCompression", "DefaultStrategy", "MCMCStrategy", "Strategy", "rasterization", + "rasterization_2dgs", "rasterization_inria_wrapper", "spherical_harmonics", "isect_offset_encode", @@ -38,5 +44,9 @@ "world_to_cam", "accumulate", "rasterize_to_indices_in_range", - "__version__", + "full_fused_projection_2dgs", + "rasterize_to_pixels_2dgs", + "rasterize_to_indices_in_range_2dgs", + "accumulate_2dgs", + "rasterization_2dgs_inria_wrapper" "__version__", ] diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py index 2585e36b3..1b06e8a59 100644 --- a/gsplat/cuda/_torch_impl.py +++ b/gsplat/cuda/_torch_impl.py @@ -6,14 +6,8 @@ from torch import Tensor -def _quat_scale_to_covar_preci( - quats: Tensor, # [N, 4], - scales: Tensor, # [N, 3], - compute_covar: bool = True, - compute_preci: bool = True, - triu: bool = False, -) -> Tuple[Optional[Tensor], Optional[Tensor]]: - """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`.""" +def _quat_to_rotmat(quats: Tensor) -> Tensor: + """Convert quaternion to rotation matrix.""" quats = F.normalize(quats, p=2, dim=-1) w, x, y, z = torch.unbind(quats, dim=-1) R = torch.stack( @@ -30,9 +24,28 @@ def _quat_scale_to_covar_preci( ], dim=-1, ) + return R.reshape(quats.shape[:-1] + (3, 3)) + + +def _quat_scale_to_matrix( + quats: Tensor, # [N, 4], + scales: Tensor, # [N, 3], +) -> Tensor: + """Convert quaternion and scale to a 3x3 matrix (R * S).""" + R = _quat_to_rotmat(quats) # (..., 3, 3) + M = R * scales[..., None, :] # (..., 3, 3) + return M + - R = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3) - # R.register_hook(lambda grad: print("grad R", grad)) +def _quat_scale_to_covar_preci( + quats: Tensor, # [N, 4], + scales: Tensor, # [N, 3], + compute_covar: bool = True, + compute_preci: bool = True, + triu: bool = False, +) -> Tuple[Optional[Tensor], Optional[Tensor]]: + """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`.""" + R = _quat_to_rotmat(quats) # (..., 3, 3) if compute_covar: M = R * scales[..., None, :] # (..., 3, 3) diff --git a/gsplat/cuda/_torch_impl_2dgs.py b/gsplat/cuda/_torch_impl_2dgs.py new file mode 100644 index 000000000..11f85546e --- /dev/null +++ b/gsplat/cuda/_torch_impl_2dgs.py @@ -0,0 +1,272 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor + +from gsplat.cuda._torch_impl import _quat_scale_to_matrix + + +def _fully_fused_projection_2dgs( + means: 
+
+
+def _fully_fused_projection_2dgs(
+    means: Tensor,  # [N, 3]
+    quats: Tensor,  # [N, 4]
+    scales: Tensor,  # [N, 3]
+    viewmats: Tensor,  # [C, 4, 4]
+    Ks: Tensor,  # [C, 3, 3]
+    width: int,
+    height: int,
+    near_plane: float = 0.01,
+    far_plane: float = 1e10,
+    eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection_2dgs()`.
+
+    .. note::
+
+        This is a minimal implementation of the fully fused version, which has more
+        arguments. Not all arguments are supported.
+    """
+    R_cw = viewmats[:, :3, :3]  # [C, 3, 3]
+    t_cw = viewmats[:, :3, 3]  # [C, 3]
+    means_c = torch.einsum("cij,nj->cni", R_cw, means) + t_cw[:, None, :]  # (C, N, 3)
+    RS_wl = _quat_scale_to_matrix(quats, scales)
+    RS_cl = torch.einsum("cij,njk->cnik", R_cw, RS_wl)  # [C, N, 3, 3]
+
+    # compute normals and flip them towards the camera
+    normals = RS_cl[..., 2]  # [C, N, 3]
+    C, N, _ = normals.shape
+    cos = -normals.reshape((C * N, 1, 3)) @ means_c.reshape((C * N, 3, 1))
+    cos = cos.reshape((C, N, 1))
+    multiplier = torch.where(cos > 0, torch.tensor(1.0), torch.tensor(-1.0))
+    normals *= multiplier
+
+    # ray transform matrix, omitting the z rotation
+    T_cl = torch.cat([RS_cl[..., :2], means_c[..., None]], dim=-1)  # [C, N, 3, 3]
+    T_sl = torch.einsum("cij,cnjk->cnik", Ks[:, :3, :3], T_cl)  # [C, N, 3, 3]
+    # in paper notation M = (WH)^T
+    # later h_u = M @ h_x, h_v = M @ h_y
+    M = torch.transpose(T_sl, -1, -2)  # [C, N, 3, 3]
+
+    # compute the AABB of gaussian
+    test = torch.tensor([1.0, 1.0, -1.0], device=means.device).reshape(1, 1, 3)
+    d = (M[..., 2] * M[..., 2] * test).sum(dim=-1, keepdim=True)  # [C, N, 1]
+    valid = torch.abs(d) > eps
+    f = torch.where(valid, test / d, torch.zeros_like(test)).unsqueeze(
+        -1
+    )  # (C, N, 3, 1)
+    means2d = (M[..., :2] * M[..., 2:3] * f).sum(dim=-2)  # [C, N, 2]
+    extents = torch.sqrt(
+        means2d**2 - (M[..., :2] * M[..., :2] * f).sum(dim=-2)
+    )  # [C, N, 2]
+
+    depths = means_c[..., 2]  # [C, N]
+    radius = torch.ceil(3.0 * torch.max(extents, dim=-1).values)  # (C, N)
+
+    valid = valid.squeeze(-1) & (depths > near_plane) & (depths < far_plane)
+    radius[~valid] = 0.0
+
+    inside = (
+        (means2d[..., 0] + radius > 0)
+        & (means2d[..., 0] - radius < width)
+        & (means2d[..., 1] + radius > 0)
+        & (means2d[..., 1] - radius < height)
+    )
+    radius[~inside] = 0.0
+    radii = radius.int()
+    return radii, means2d, depths, M, normals
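+
+# Illustrative usage of the reference projection above (shapes only; tensors
+# are assumed to live on the same device):
+#
+#   radii, means2d, depths, ray_transforms, normals = _fully_fused_projection_2dgs(
+#       means, quats, scales, viewmats, Ks, width, height
+#   )  # [C, N], [C, N, 2], [C, N], [C, N, 3, 3], [C, N, 3]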
+
+
+def accumulate_2dgs(
+    means2d: Tensor,  # [C, N, 2]
+    ray_transforms: Tensor,  # [C, N, 3, 3]
+    opacities: Tensor,  # [C, N]
+    colors: Tensor,  # [C, N, channels]
+    normals: Tensor,  # [C, N, 3]
+    gaussian_ids: Tensor,  # [M]
+    pixel_ids: Tensor,  # [M]
+    camera_ids: Tensor,  # [M]
+    image_width: int,
+    image_height: int,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Alpha compositing for 2DGS.
+
+    .. warning::
+        This function requires the nerfacc package to be installed. Please install it
+        using the following command: ``pip install nerfacc``.
+
+    Args:
+        means2d: Gaussian means in 2D. [C, N, 2]
+        ray_transforms: Transformation matrices that transform rays in pixel space into the splat's local frame. [C, N, 3, 3]
+        opacities: Per-view Gaussian opacities (for example, when antialiasing is enabled, a Gaussian in
+            each view would effectively have a different opacity). [C, N]
+        colors: Per-view Gaussian colors. Supports N-D features. [C, N, channels]
+        normals: Per-view Gaussian normals. [C, N, 3]
+        gaussian_ids: Collection of Gaussian indices to be rasterized. A flattened list of shape [M].
+        pixel_ids: Collection of pixel indices (row-major) to be rasterized. A flattened list of shape [M].
+        camera_ids: Collection of camera indices to be rasterized. A flattened list of shape [M].
+        image_width: Image width.
+        image_height: Image height.
+
+    Returns:
+        A tuple:
+
+        **renders**: Accumulated colors. [C, image_height, image_width, channels]
+
+        **alphas**: Accumulated opacities. [C, image_height, image_width, 1]
+
+        **normals**: Accumulated normals. [C, image_height, image_width, 3]
+    """
+
+    try:
+        from nerfacc import accumulate_along_rays, render_weight_from_alpha
+    except ImportError:
+        raise ImportError("Please install nerfacc package: pip install nerfacc")
+
+    C, N = means2d.shape[:2]
+    channels = colors.shape[-1]
+
+    pixel_ids_x = pixel_ids % image_width + 0.5
+    pixel_ids_y = pixel_ids // image_width + 0.5
+    pixel_coords = torch.stack([pixel_ids_x, pixel_ids_y], dim=-1)  # [M, 2]
+    deltas = pixel_coords - means2d[camera_ids, gaussian_ids]  # [M, 2]
+
+    M = ray_transforms[camera_ids, gaussian_ids]  # [M, 3, 3]
+
+    # ray-splat intersection in the splat's tangent frame
+    h_u = -M[..., 0, :3] + M[..., 2, :3] * pixel_ids_x[..., None]  # [M, 3]
+    h_v = -M[..., 1, :3] + M[..., 2, :3] * pixel_ids_y[..., None]  # [M, 3]
+    tmp = torch.cross(h_u, h_v, dim=-1)
+    us = tmp[..., 0] / tmp[..., 2]
+    vs = tmp[..., 1] / tmp[..., 2]
+    sigmas_3d = us**2 + vs**2  # [M]
+    sigmas_2d = 2 * (deltas[..., 0] ** 2 + deltas[..., 1] ** 2)
+    sigmas = 0.5 * torch.minimum(sigmas_3d, sigmas_2d)  # [M]
+
+    alphas = torch.clamp_max(
+        opacities[camera_ids, gaussian_ids] * torch.exp(-sigmas), 0.999
+    )
+
+    indices = camera_ids * image_height * image_width + pixel_ids
+    total_pixels = C * image_height * image_width
+
+    weights, trans = render_weight_from_alpha(
+        alphas, ray_indices=indices, n_rays=total_pixels
+    )
+    renders = accumulate_along_rays(
+        weights,
+        colors[camera_ids, gaussian_ids],
+        ray_indices=indices,
+        n_rays=total_pixels,
+    ).reshape(C, image_height, image_width, channels)
+    alphas = accumulate_along_rays(
+        weights, None, ray_indices=indices, n_rays=total_pixels
+    ).reshape(C, image_height, image_width, 1)
+    renders_normal = accumulate_along_rays(
+        weights,
+        normals[camera_ids, gaussian_ids],
+        ray_indices=indices,
+        n_rays=total_pixels,
+    ).reshape(C, image_height, image_width, 3)
+
+    return renders, alphas, renders_normal
+
+
+def _rasterize_to_pixels_2dgs(
+    means2d: Tensor,  # [C, N, 2]
+    ray_transforms: Tensor,  # [C, N, 3, 3]
+    colors: Tensor,  # [C, N, channels]
+    normals: Tensor,  # [C, N, 3]
+    opacities: Tensor,  # [C, N]
+    image_width: int,
+    image_height: int,
+    tile_size: int,
+    isect_offsets: Tensor,  # [C, tile_height, tile_width]
+    flatten_ids: Tensor,  # [n_isects]
+    backgrounds: Optional[Tensor] = None,  # [C, channels]
+    batch_per_iter: int = 100,
+):
+    """PyTorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels_2dgs()`.
+
+    This function rasterizes 2D Gaussians to pixels in a PyTorch-friendly way. It
+    iteratively accumulates the renderings within each batch of Gaussians. The
+    iterations are controlled by `batch_per_iter`.
+
+    .. note::
+        This is a minimal implementation of the fully fused version, which has more
+        arguments. Not all arguments are supported.
+
+    .. note::
+
+        This function relies on PyTorch's autograd for the backpropagation. It is much slower
+        than our fully fused rasterization implementation and consumes much more GPU memory.
+        But it could serve as a playground for new ideas or debugging, as no backward
+        implementation is needed.
+
+    .. warning::
+
+        This function requires the `nerfacc` package to be installed. Please install it
+        using the following command: `pip install nerfacc`.
+    """
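+    # Batched front-to-back compositing: at each iteration, the colors rendered
+    # from the current batch of Gaussians are weighted by the transmittance
+    # accumulated over all previous batches, i.e.
+    #     final = sum over batches of (transmittance_so_far * batch_render),
+    # which matches what the fused kernel computes in a single pass.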
+ """ + from ._wrapper import rasterize_to_indices_in_range_2dgs + + C, N = means2d.shape[:2] + n_isects = len(flatten_ids) + device = means2d.device + + render_colors = torch.zeros( + (C, image_height, image_width, colors.shape[-1]), device=device + ) + render_alphas = torch.zeros((C, image_height, image_width, 1), device=device) + render_normals = torch.zeros((C, image_height, image_width, 3), device=device) + + # Split Gaussians into batches and iteratively accumulate the renderings + block_size = tile_size * tile_size + isect_offsets_fl = torch.cat( + [isect_offsets.flatten(), torch.tensor([n_isects], device=device)] + ) + max_range = (isect_offsets_fl[1:] - isect_offsets_fl[:-1]).max().item() + num_batches = (max_range + block_size - 1) // block_size + for step in range(0, num_batches, batch_per_iter): + transmittances = 1.0 - render_alphas[..., 0] + + # Find the M intersections between pixels and gaussians. + # Each intersection corresponds to a tuple (gs_id, pixel_id, camera_id) + gs_ids, pixel_ids, camera_ids = rasterize_to_indices_in_range_2dgs( + step, + step + batch_per_iter, + transmittances, + means2d, + ray_transforms, + opacities, + image_width, + image_height, + tile_size, + isect_offsets, + flatten_ids, + ) # [M], [M] + if len(gs_ids) == 0: + break + + # Accumulate the renderings within this batch of Gaussians. + renders_step, accs_step, renders_normal_step = accumulate_2dgs( + means2d, + ray_transforms, + opacities, + colors, + normals, + gs_ids, + pixel_ids, + camera_ids, + image_width, + image_height, + ) + render_colors = render_colors + renders_step * transmittances[..., None] + render_alphas = render_alphas + accs_step * transmittances[..., None] + render_normals = ( + render_normals + renders_normal_step * transmittances[..., None] + ) + + render_alphas = render_alphas + if backgrounds is not None: + render_colors = render_colors + backgrounds[:, None, None, :] * ( + 1.0 - render_alphas + ) + + return render_colors, render_alphas, render_normals diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py index ded7d5989..79dcfd29f 100644 --- a/gsplat/cuda/_wrapper.py +++ b/gsplat/cuda/_wrapper.py @@ -1207,3 +1207,743 @@ def backward(ctx, v_colors: Tensor): if not compute_v_dirs: v_dirs = None return None, v_dirs, v_coeffs, None + + +###### 2DGS ###### +def fully_fused_projection_2dgs( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, + viewmats: Tensor, + Ks: Tensor, + width: int, + height: int, + eps2d: float = 0.3, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + packed: bool = False, + sparse_grad: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + """Prepare Gaussians for rasterization + + This function prepares ray-splat intersection matrices, computes + per splat bounding box and 2D means in image space. + + Args: + means: Gaussian means. [N, 3] + quats: Quaternions (No need to be normalized). [N, 4]. + scales: Scales. [N, 3]. + viewmats: Camera-to-world matrices. [C, 4, 4] + Ks: Camera intrinsics. [C, 3, 3] + width: Image width. + height: Image height. + near_plane: Near plane distance. Default: 0.01. + far_plane: Far plane distance. Default: 200. + radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0. + packed: If True, the output tensors will be packed into a flattened tensor. Default: False. + sparse_grad (Experimental): This is only effective when `packed` is True. 
+    """
+    C = viewmats.size(0)
+    N = means.size(0)
+    assert means.size() == (N, 3), means.size()
+    assert viewmats.size() == (C, 4, 4), viewmats.size()
+    assert Ks.size() == (C, 3, 3), Ks.size()
+    means = means.contiguous()
+    assert quats is not None, "quats is required"
+    assert scales is not None, "scales is required"
+    assert quats.size() == (N, 4), quats.size()
+    assert scales.size() == (N, 3), scales.size()
+    quats = quats.contiguous()
+    scales = scales.contiguous()
+    if sparse_grad:
+        assert packed, "sparse_grad is only supported when packed is True"
+
+    viewmats = viewmats.contiguous()
+    Ks = Ks.contiguous()
+    if packed:
+        return _FullyFusedProjectionPacked2DGS.apply(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            near_plane,
+            far_plane,
+            radius_clip,
+            sparse_grad,
+        )
+    else:
+        return _FullyFusedProjection2DGS.apply(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            eps2d,
+            near_plane,
+            far_plane,
+            radius_clip,
+        )
+
+
+class _FullyFusedProjection2DGS(torch.autograd.Function):
+    """Projects Gaussians to 2D."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        means: Tensor,
+        quats: Tensor,
+        scales: Tensor,
+        viewmats: Tensor,
+        Ks: Tensor,
+        width: int,
+        height: int,
+        eps2d: float,
+        near_plane: float,
+        far_plane: float,
+        radius_clip: float,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        radii, means2d, depths, ray_transforms, normals = _make_lazy_cuda_func(
+            "fully_fused_projection_fwd_2dgs"
+        )(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            eps2d,
+            near_plane,
+            far_plane,
+            radius_clip,
+        )
+        ctx.save_for_backward(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            radii,
+            ray_transforms,
+            normals,
+        )
+        ctx.width = width
+        ctx.height = height
+        ctx.eps2d = eps2d
+
+        return radii, means2d, depths, ray_transforms, normals
+
+    @staticmethod
+    def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
+        (
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            radii,
+            ray_transforms,
+            normals,
+        ) = ctx.saved_tensors
+        width = ctx.width
+        height = ctx.height
+        eps2d = ctx.eps2d
+        v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
+            "fully_fused_projection_bwd_2dgs"
+        )(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            radii,
+            ray_transforms,
+            v_means2d.contiguous(),
+            v_depths.contiguous(),
+            v_normals.contiguous(),
+            v_ray_transforms.contiguous(),
+            ctx.needs_input_grad[3],  # viewmats_requires_grad
+        )
+        if not ctx.needs_input_grad[0]:
+            v_means = None
+        if not ctx.needs_input_grad[1]:
+            v_quats = None
+        if not ctx.needs_input_grad[2]:
+            v_scales = None
+        if not ctx.needs_input_grad[3]:
+            v_viewmats = None
+
+        return (
+            v_means,
+            v_quats,
+            v_scales,
+            v_viewmats,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+class _FullyFusedProjectionPacked2DGS(torch.autograd.Function):
+    """Projects Gaussians to 2D. Returns packed tensors."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        means: Tensor,  # [N, 3]
+        quats: Tensor,  # [N, 4]
+        scales: Tensor,  # [N, 3]
+        viewmats: Tensor,  # [C, 4, 4]
+        Ks: Tensor,  # [C, 3, 3]
+        width: int,
+        height: int,
+        near_plane: float,
+        far_plane: float,
+        radius_clip: float,
+        sparse_grad: bool,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+        (
+            indptr,
+            camera_ids,
+            gaussian_ids,
+            radii,
+            means2d,
+            depths,
+            ray_transforms,
+            normals,
+        ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd_2dgs")(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            near_plane,
+            far_plane,
+            radius_clip,
+        )
+        ctx.save_for_backward(
+            camera_ids,
+            gaussian_ids,
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            ray_transforms,
+        )
+        ctx.width = width
+        ctx.height = height
+        ctx.sparse_grad = sparse_grad
+
+        return camera_ids, gaussian_ids, radii, means2d, depths, ray_transforms, normals
+
+    @staticmethod
+    def backward(
+        ctx,
+        v_camera_ids,
+        v_gaussian_ids,
+        v_radii,
+        v_means2d,
+        v_depths,
+        v_ray_transforms,
+        v_normals,
+    ):
+        (
+            camera_ids,
+            gaussian_ids,
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            ray_transforms,
+        ) = ctx.saved_tensors
+        width = ctx.width
+        height = ctx.height
+        sparse_grad = ctx.sparse_grad
+
+        v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
+            "fully_fused_projection_packed_bwd_2dgs"
+        )(
+            means,
+            quats,
+            scales,
+            viewmats,
+            Ks,
+            width,
+            height,
+            camera_ids,
+            gaussian_ids,
+            ray_transforms,
+            v_means2d.contiguous(),
+            v_depths.contiguous(),
+            v_ray_transforms.contiguous(),
+            v_normals.contiguous(),
+            ctx.needs_input_grad[3],  # viewmats_requires_grad (viewmats is forward input 3)
+            sparse_grad,
+        )
+
+        if not ctx.needs_input_grad[0]:
+            v_means = None
+        else:
+            if sparse_grad:
+                # TODO: gaussian_ids is duplicated so not ideal.
+                # An idea is to directly set the attribute (e.g., .sparse_grad) of
+                # the tensor but this requires the tensor to be leaf node only. And
+                # a customized optimizer would be needed in this case.
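+                # Scatter the [nnz, 3] per-intersection gradients back onto the N
+                # Gaussians as a sparse COO tensor: `gaussian_ids` provides the row
+                # indices and duplicates are summed on coalesce. With a single
+                # camera each Gaussian appears at most once and the ids are already
+                # sorted, hence `is_coalesced=len(viewmats) == 1`.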
+                v_means = torch.sparse_coo_tensor(
+                    indices=gaussian_ids[None],  # [1, nnz]
+                    values=v_means,  # [nnz, 3]
+                    size=means.size(),  # [N, 3]
+                    is_coalesced=len(viewmats) == 1,
+                )
+        if not ctx.needs_input_grad[1]:
+            v_quats = None
+        else:
+            if sparse_grad:
+                v_quats = torch.sparse_coo_tensor(
+                    indices=gaussian_ids[None],  # [1, nnz]
+                    values=v_quats,  # [nnz, 4]
+                    size=quats.size(),  # [N, 4]
+                    is_coalesced=len(viewmats) == 1,
+                )
+        if not ctx.needs_input_grad[2]:
+            v_scales = None
+        else:
+            if sparse_grad:
+                v_scales = torch.sparse_coo_tensor(
+                    indices=gaussian_ids[None],  # [1, nnz]
+                    values=v_scales,  # [nnz, 3]
+                    size=scales.size(),  # [N, 3]
+                    is_coalesced=len(viewmats) == 1,
+                )
+        if not ctx.needs_input_grad[3]:
+            v_viewmats = None
+
+        return (
+            v_means,
+            v_quats,
+            v_scales,
+            v_viewmats,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
+
+
+def rasterize_to_pixels_2dgs(
+    means2d: Tensor,
+    ray_transforms: Tensor,
+    colors: Tensor,
+    opacities: Tensor,
+    normals: Tensor,
+    densify: Tensor,
+    image_width: int,
+    image_height: int,
+    tile_size: int,
+    isect_offsets: Tensor,
+    flatten_ids: Tensor,
+    backgrounds: Optional[Tensor] = None,
+    masks: Optional[Tensor] = None,
+    packed: bool = False,
+    absgrad: bool = False,
+    distloss: bool = False,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """Rasterize Gaussians to pixels.
+
+    Args:
+        means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True.
+        ray_transforms: Transformation matrices that transform xy-planes in pixel space into splat coordinates. [C, N, 3, 3] if packed is False, [nnz, 3, 3] if packed is True.
+        colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True.
+        opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True.
+        normals: The normals in camera space. [C, N, 3] if packed is False, [nnz, 3] if packed is True.
+        densify: Dummy variable to keep track of gradient for densification. [C, N, 2] if packed is False, [nnz, 2] if packed is True.
+        image_width: Image width.
+        image_height: Image height.
+        tile_size: Tile size.
+        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width]
+        flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects]
+        backgrounds: Background colors. [C, channels]. Default: None.
+        masks: Optional tile mask to skip rendering GS to masked tiles. [C, tile_height, tile_width]. Default: None.
+        packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
+        absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.
+        distloss: If True, enable the distortion loss (experimental). Default: False.
+
+    Returns:
+        A tuple:
+
+        - **Rendered colors**. [C, image_height, image_width, channels]
+        - **Rendered alphas**. [C, image_height, image_width, 1]
+        - **Rendered normals**. [C, image_height, image_width, 3]
+        - **Rendered distortion**. [C, image_height, image_width, 1]
+        - **Rendered median depth**. [C, image_height, image_width, 1]
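+
+    Example (a toy single-tile sketch with random splats; in a real pipeline
+    `isect_offsets` and `flatten_ids` come from `isect_tiles()` /
+    `isect_offset_encode()` with depth-sorted splats):
+
+    .. code-block:: python
+
+        import torch
+
+        C, N, channels = 1, 128, 3
+        width = height = tile_size = 16  # a single tile covers the image
+        means2d = torch.rand(C, N, 2, device="cuda") * width
+        ray_transforms = torch.randn(C, N, 3, 3, device="cuda")
+        colors = torch.rand(C, N, channels, device="cuda")
+        opacities = torch.rand(C, N, device="cuda")
+        normals = torch.randn(C, N, 3, device="cuda")
+        densify = torch.zeros(C, N, 2, device="cuda", requires_grad=True)
+        isect_offsets = torch.zeros(C, 1, 1, dtype=torch.int32, device="cuda")
+        flatten_ids = torch.arange(C * N, dtype=torch.int32, device="cuda")
+        renders, alphas, normals_img, distort, median = rasterize_to_pixels_2dgs(
+            means2d, ray_transforms, colors, opacities, normals, densify,
+            width, height, tile_size, isect_offsets, flatten_ids,
+        )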
+    """
+    C = isect_offsets.size(0)
+    device = means2d.device
+    if packed:
+        nnz = means2d.size(0)
+        assert means2d.shape == (nnz, 2), means2d.shape
+        assert ray_transforms.shape == (nnz, 3, 3), ray_transforms.shape
+        assert colors.shape[0] == nnz, colors.shape
+        assert opacities.shape == (nnz,), opacities.shape
+    else:
+        N = means2d.size(1)
+        assert means2d.shape == (C, N, 2), means2d.shape
+        assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape
+        assert colors.shape[:2] == (C, N), colors.shape
+        assert opacities.shape == (C, N), opacities.shape
+    if backgrounds is not None:
+        assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape
+        backgrounds = backgrounds.contiguous()
+
+    # Pad the channels to the nearest supported number if necessary
+    channels = colors.shape[-1]
+    if channels > 512 or channels == 0:
+        # TODO: maybe worth to support zero channels?
+        raise ValueError(f"Unsupported number of color channels: {channels}")
+    if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512):
+        padded_channels = (1 << (channels - 1).bit_length()) - channels
+        colors = torch.cat(
+            [colors, torch.empty(*colors.shape[:-1], padded_channels, device=device)],
+            dim=-1,
+        )
+        if backgrounds is not None:
+            backgrounds = torch.cat(
+                [
+                    backgrounds,
+                    torch.empty(
+                        *backgrounds.shape[:-1], padded_channels, device=device
+                    ),
+                ],
+                dim=-1,
+            )
+    else:
+        padded_channels = 0
+    tile_height, tile_width = isect_offsets.shape[1:3]
+    assert (
+        tile_height * tile_size >= image_height
+    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
+    assert (
+        tile_width * tile_size >= image_width
+    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"
+
+    (
+        render_colors,
+        render_alphas,
+        render_normals,
+        render_distort,
+        render_median,
+    ) = _RasterizeToPixels2DGS.apply(
+        means2d.contiguous(),
+        ray_transforms.contiguous(),
+        colors.contiguous(),
+        opacities.contiguous(),
+        normals.contiguous(),
+        densify.contiguous(),
+        backgrounds,
+        masks,
+        image_width,
+        image_height,
+        tile_size,
+        isect_offsets.contiguous(),
+        flatten_ids.contiguous(),
+        absgrad,
+        distloss,
+    )
+
+    if padded_channels > 0:
+        render_colors = render_colors[..., :-padded_channels]
+
+    return render_colors, render_alphas, render_normals, render_distort, render_median
+
+
+@torch.no_grad()
+def rasterize_to_indices_in_range_2dgs(
+    range_start: int,
+    range_end: int,
+    transmittances: Tensor,
+    means2d: Tensor,
+    ray_transforms: Tensor,
+    opacities: Tensor,
+    image_width: int,
+    image_height: int,
+    tile_size: int,
+    isect_offsets: Tensor,
+    flatten_ids: Tensor,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """Rasterizes a batch of Gaussians to images but only returns the indices.
+
+    .. note::
+
+        This function supports iterative rasterization, in which each call of this function
+        will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`.
+        If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really
+        large number, e.g., 1e10.
+
+    Args:
+        range_start: The start batch of Gaussians to be rasterized (inclusive).
+        range_end: The end batch of Gaussians to be rasterized (exclusive).
+        transmittances: Current transmittances. [C, image_height, image_width]
+        means2d: Projected Gaussian means. [C, N, 2]
+        ray_transforms: Transformation matrices that transform xy-planes in pixel space into splat coordinates. [C, N, 3, 3]
+        opacities: Gaussian opacities that support per-view values. [C, N]
+        image_width: Image width.
+        image_height: Image height.
+        tile_size: Tile size.
+        isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width]
+        flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects]
+
+    Returns:
+        A tuple:
+
+        - **Gaussian ids**. Gaussian ids for the pixel intersections. A flattened list of shape [M].
+        - **Pixel ids**. Pixel indices (row-major). A flattened list of shape [M].
+        - **Camera ids**. Camera indices. A flattened list of shape [M].
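+
+    Example (a toy single-tile sketch; real intersection tensors come from
+    `isect_tiles()` / `isect_offset_encode()`):
+
+    .. code-block:: python
+
+        import torch
+
+        C, N = 1, 64
+        width = height = tile_size = 16  # a single tile covers the image
+        means2d = torch.rand(C, N, 2, device="cuda") * width
+        ray_transforms = torch.randn(C, N, 3, 3, device="cuda")
+        opacities = torch.rand(C, N, device="cuda")
+        isect_offsets = torch.zeros(C, 1, 1, dtype=torch.int32, device="cuda")
+        flatten_ids = torch.arange(C * N, dtype=torch.int32, device="cuda")
+        # fresh render: all pixels start fully transparent (transmittance 1)
+        transmittances = torch.ones(C, height, width, device="cuda")
+        gs_ids, pixel_ids, camera_ids = rasterize_to_indices_in_range_2dgs(
+            0, 1 << 30, transmittances, means2d, ray_transforms, opacities,
+            width, height, tile_size, isect_offsets, flatten_ids,
+        )  # one-shot: take every batch in a single range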
+    """
+
+    C, N, _ = means2d.shape
+    assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape
+    assert opacities.shape == (C, N), opacities.shape
+    assert isect_offsets.shape[0] == C, isect_offsets.shape
+
+    tile_height, tile_width = isect_offsets.shape[1:3]
+    assert (
+        tile_height * tile_size >= image_height
+    ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
+    assert (
+        tile_width * tile_size >= image_width
+    ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"
+
+    out_gauss_ids, out_indices = _make_lazy_cuda_func(
+        "rasterize_to_indices_in_range_2dgs"
+    )(
+        range_start,
+        range_end,
+        transmittances.contiguous(),
+        means2d.contiguous(),
+        ray_transforms.contiguous(),
+        opacities.contiguous(),
+        image_width,
+        image_height,
+        tile_size,
+        isect_offsets.contiguous(),
+        flatten_ids.contiguous(),
+    )
+    out_pixel_ids = out_indices % (image_width * image_height)
+    out_camera_ids = out_indices // (image_width * image_height)
+    return out_gauss_ids, out_pixel_ids, out_camera_ids
+
+
+class _RasterizeToPixels2DGS(torch.autograd.Function):
+    """Rasterize Gaussians to pixels for 2DGS."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        means2d: Tensor,
+        ray_transforms: Tensor,
+        colors: Tensor,
+        opacities: Tensor,
+        normals: Tensor,
+        densify: Tensor,
+        backgrounds: Tensor,
+        masks: Tensor,
+        width: int,
+        height: int,
+        tile_size: int,
+        isect_offsets: Tensor,
+        flatten_ids: Tensor,
+        absgrad: bool,
+        distloss: bool,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+        (
+            render_colors,
+            render_alphas,
+            render_normals,
+            render_distort,
+            render_median,
+            last_ids,
+            median_ids,
+        ) = _make_lazy_cuda_func("rasterize_to_pixels_fwd_2dgs")(
+            means2d,
+            ray_transforms,
+            colors,
+            opacities,
+            normals,
+            backgrounds,
+            masks,
+            width,
+            height,
+            tile_size,
+            isect_offsets,
+            flatten_ids,
+        )
+
+        ctx.save_for_backward(
+            means2d,
+            ray_transforms,
+            colors,
+            opacities,
+            normals,
+            densify,
+            backgrounds,
+            masks,
+            isect_offsets,
+            flatten_ids,
+            render_colors,
+            render_alphas,
+            last_ids,
+            median_ids,
+        )
+        ctx.width = width
+        ctx.height = height
+        ctx.tile_size = tile_size
+        ctx.absgrad = absgrad
+        ctx.distloss = distloss
+
+        # double to float
+        render_alphas = render_alphas.float()
+        return (
+            render_colors,
+            render_alphas,
+            render_normals,
+            render_distort,
+            render_median,
+        )
+
+    @staticmethod
+    def backward(
+        ctx,
+        v_render_colors: Tensor,
+        v_render_alphas: Tensor,
+        v_render_normals: Tensor,
+        v_render_distort: Tensor,
+        v_render_median: Tensor,
+    ):
+
+        (
+            means2d,
+            ray_transforms,
+            colors,
+            opacities,
+            normals,
+            densify,
+            backgrounds,
+            masks,
+            isect_offsets,
+            flatten_ids,
+            render_colors,
+            render_alphas,
+            last_ids,
+            median_ids,
+        ) = ctx.saved_tensors
+        width = ctx.width
+        height = ctx.height
+        tile_size = ctx.tile_size
+        absgrad = ctx.absgrad
+
+        (
+            v_means2d_abs,
+            v_means2d,
+            v_ray_transforms,
+            v_colors,
+            v_opacities,
+            v_normals,
+            v_densify,
+        ) = 
_make_lazy_cuda_func("rasterize_to_pixels_bwd_2dgs")( + means2d, + ray_transforms, + colors, + opacities, + normals, + densify, + backgrounds, + masks, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + render_colors, + render_alphas, + last_ids, + median_ids, + v_render_colors.contiguous(), + v_render_alphas.contiguous(), + v_render_normals.contiguous(), + v_render_distort.contiguous(), + v_render_median.contiguous(), + absgrad, + ) + torch.cuda.synchronize() + if absgrad: + means2d.absgrad = v_means2d_abs + + if ctx.needs_input_grad[6]: + v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum( + dim=(1, 2) + ) + else: + v_backgrounds = None + + return ( + v_means2d, + v_ray_transforms, + v_colors, + v_opacities, + v_normals, + v_densify, + v_backgrounds, + None, + None, + None, + None, + None, + None, + None, + None, + ) diff --git a/gsplat/cuda/csrc/CMakeLists.txt b/gsplat/cuda/csrc/CMakeLists.txt index a315d9c91..0fbf46ba9 100644 --- a/gsplat/cuda/csrc/CMakeLists.txt +++ b/gsplat/cuda/csrc/CMakeLists.txt @@ -63,4 +63,4 @@ install(TARGETS gsplat LIBRARY DESTINATION lib ) -message(STATUS "CMake configuration done!") +message(STATUS "CMake configuration done!") \ No newline at end of file diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h index 004709a4d..9e2f8328c 100644 --- a/gsplat/cuda/csrc/bindings.h +++ b/gsplat/cuda/csrc/bindings.h @@ -309,6 +309,178 @@ std::tuple compute_relocation_tensor( const int n_max ); +//====== 2DGS ======// +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_fwd_2dgs_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip +); + +std::tuple +fully_fused_projection_bwd_2dgs_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + // fwd outputs + const torch::Tensor &radii, // [C, N] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] + // grad outputs + const torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] + const torch::Tensor &v_ray_transforms, // [C, N, 3, 3] + const bool viewmats_requires_grad +); + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_fwd_2dgs_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor 
&flatten_ids // [n_isects]
+);
+
+std::tuple<
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor>
+rasterize_to_pixels_bwd_2dgs_tensor(
+    // Gaussian parameters
+    const torch::Tensor &means2d,        // [C, N, 2] or [nnz, 2]
+    const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3]
+    const torch::Tensor &colors,         // [C, N, channels] or [nnz, channels]
+    const torch::Tensor &opacities,      // [C, N] or [nnz]
+    const torch::Tensor &normals,        // [C, N, 3] or [nnz, 3]
+    const torch::Tensor &densify,
+    const at::optional &backgrounds, // [C, channels]
+    const at::optional &masks,       // [C, tile_height, tile_width]
+    // image size
+    const uint32_t image_width,
+    const uint32_t image_height,
+    const uint32_t tile_size,
+    // intersections
+    const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+    const torch::Tensor &flatten_ids,  // [n_isects]
+    // forward outputs
+    const torch::Tensor
+        &render_colors, // [C, image_height, image_width, COLOR_DIM]
+    const torch::Tensor &render_alphas, // [C, image_height, image_width, 1]
+    const torch::Tensor &last_ids,      // [C, image_height, image_width]
+    const torch::Tensor &median_ids,    // [C, image_height, image_width]
+    // gradients of outputs
+    const torch::Tensor &v_render_colors,  // [C, image_height, image_width, 3]
+    const torch::Tensor &v_render_alphas,  // [C, image_height, image_width, 1]
+    const torch::Tensor &v_render_normals, // [C, image_height, image_width, 3]
+    const torch::Tensor &v_render_distort, // [C, image_height, image_width, 1]
+    const torch::Tensor &v_render_median,  // [C, image_height, image_width, 1]
+    // options
+    bool absgrad
+);
+
+std::tuple
+rasterize_to_indices_in_range_2dgs_tensor(
+    const uint32_t range_start,
+    const uint32_t range_end,           // iteration steps
+    const torch::Tensor transmittances, // [C, image_height, image_width]
+    // Gaussian parameters
+    const torch::Tensor &means2d,        // [C, N, 2]
+    const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+    const torch::Tensor &opacities,      // [C, N]
+    // image size
+    const uint32_t image_width,
+    const uint32_t image_height,
+    const uint32_t tile_size,
+    // intersections
+    const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+    const torch::Tensor &flatten_ids   // [n_isects]
+);
+
+std::tuple<
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor,
+    torch::Tensor>
+fully_fused_projection_packed_fwd_2dgs_tensor(
+    const torch::Tensor &means,    // [N, 3]
+    const torch::Tensor &quats,    // [N, 4]
+    const torch::Tensor &scales,   // [N, 3]
+    const torch::Tensor &viewmats, // [C, 4, 4]
+    const torch::Tensor &Ks,       // [C, 3, 3]
+    const uint32_t image_width,
+    const uint32_t image_height,
+    const float near_plane,
+    const float far_plane,
+    const float radius_clip
+);
+
+std::tuple
+fully_fused_projection_packed_bwd_2dgs_tensor(
+    // fwd inputs
+    const torch::Tensor &means,    // [N, 3]
+    const torch::Tensor &quats,    // [N, 4]
+    const torch::Tensor &scales,   // [N, 3]
+    const torch::Tensor &viewmats, // [C, 4, 4]
+    const torch::Tensor &Ks,       // [C, 3, 3]
+    const uint32_t image_width,
+    const uint32_t image_height,
+    // fwd outputs
+    const torch::Tensor &camera_ids,     // [nnz]
+    const torch::Tensor &gaussian_ids,   // [nnz]
+    const torch::Tensor &ray_transforms, // [nnz, 3, 3]
+    // grad outputs
+    const torch::Tensor &v_means2d,        // [nnz, 2]
+    const torch::Tensor &v_depths,         // [nnz]
+    const torch::Tensor &v_ray_transforms, // [nnz, 3, 3]
+    const torch::Tensor &v_normals,        // [nnz, 3]
const bool viewmats_requires_grad, + const bool sparse_grad +); + } // namespace gsplat -#endif // GSPLAT_CUDA_BINDINGS_H \ No newline at end of file +#endif // GSPLAT_CUDA_BINDINGS_H diff --git a/gsplat/cuda/csrc/compute_sh_bwd.cu b/gsplat/cuda/csrc/compute_sh_bwd.cu index 4a4fda3f6..5b22cbdbe 100644 --- a/gsplat/cuda/csrc/compute_sh_bwd.cu +++ b/gsplat/cuda/csrc/compute_sh_bwd.cu @@ -93,4 +93,4 @@ std::tuple compute_sh_bwd_tensor( return std::make_tuple(v_coeffs, v_dirs); // [..., K, 3], [..., 3] } -} // namespace gsplat \ No newline at end of file +} // namespace gsplat diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp index 57248ac81..0a4a67aac 100644 --- a/gsplat/cuda/csrc/ext.cpp +++ b/gsplat/cuda/csrc/ext.cpp @@ -48,4 +48,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ); m.def("compute_relocation", &gsplat::compute_relocation_tensor); -} \ No newline at end of file + + // 2DGS + m.def( + "fully_fused_projection_fwd_2dgs", + &gsplat::fully_fused_projection_fwd_2dgs_tensor + ); + m.def( + "fully_fused_projection_bwd_2dgs", + &gsplat::fully_fused_projection_bwd_2dgs_tensor + ); + + m.def( + "fully_fused_projection_packed_fwd_2dgs", + &gsplat::fully_fused_projection_packed_fwd_2dgs_tensor + ); + m.def( + "fully_fused_projection_packed_bwd_2dgs", + &gsplat::fully_fused_projection_packed_bwd_2dgs_tensor + ); + + m.def( + "rasterize_to_pixels_fwd_2dgs", + &gsplat::rasterize_to_pixels_fwd_2dgs_tensor + ); + m.def( + "rasterize_to_pixels_bwd_2dgs", + &gsplat::rasterize_to_pixels_bwd_2dgs_tensor + ); + + m.def( + "rasterize_to_indices_in_range_2dgs", + &gsplat::rasterize_to_indices_in_range_2dgs_tensor + ); +} diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu new file mode 100644 index 000000000..ec7eb1126 --- /dev/null +++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu @@ -0,0 +1,224 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Projection of Gaussians (Batched) Backward Pass + ****************************************************************************/ +template +__global__ void fully_fused_projection_bwd_2dgs_kernel( + // fwd inputs + const uint32_t C, + const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ quats, // [N, 4] + const T *__restrict__ scales, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, + const int32_t image_height, + // fwd outputs + const int32_t *__restrict__ radii, // [C, N] + const T *__restrict__ ray_transforms, // [C, N, 3, 3] + // grad outputs + const T *__restrict__ v_means2d, // [C, N, 2] + const T *__restrict__ v_depths, // [C, N] + const T *__restrict__ v_normals, // [C, N, 3] + // grad inputs + T *__restrict__ v_ray_transforms, // [C, N, 3, 3] + T *__restrict__ v_means, // [N, 3] + T *__restrict__ v_quats, // [N, 4] + T *__restrict__ v_scales, // [N, 3] + T *__restrict__ v_viewmats // [C, 4, 4] +) { + // parallelize over C * N. 
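+    // One thread per (camera, gaussian) pair: cid = idx / N, gid = idx % N.
+    // Splats that were culled in the forward pass (radii <= 0) exit right
+    // away, so their parameter gradients remain zero.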
+ uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N || radii[idx] <= 0) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + ray_transforms += idx * 9; + + v_means2d += idx * 2; + v_depths += idx; + v_normals += idx * 3; + v_ray_transforms += idx * 9; + + // transform Gaussian to camera space + mat3 R = mat3( + viewmats[0], + viewmats[4], + viewmats[8], // 1st column + viewmats[1], + viewmats[5], + viewmats[9], // 2nd column + viewmats[2], + viewmats[6], + viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + + vec4 quat = glm::make_vec4(quats + gid * 4); + vec2 scale = glm::make_vec2(scales + gid * 3); + + mat3 P = mat3(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0); + + mat3 _v_ray_transforms = mat3( + v_ray_transforms[0], + v_ray_transforms[1], + v_ray_transforms[2], + v_ray_transforms[3], + v_ray_transforms[4], + v_ray_transforms[5], + v_ray_transforms[6], + v_ray_transforms[7], + v_ray_transforms[8] + ); + + _v_ray_transforms[2][2] += v_depths[0]; + + vec3 v_normal = glm::make_vec3(v_normals); + + vec3 v_mean(0.f); + vec2 v_scale(0.f); + vec4 v_quat(0.f); + compute_ray_transforms_aabb_vjp( + ray_transforms, + v_means2d, + v_normal, + R, + P, + t, + mean_c, + quat, + scale, + _v_ray_transforms, + v_quat, + v_scale, + v_mean + ); + + // #if __CUDA_ARCH__ >= 700 + // write out results with warp-level reduction + auto warp = cg::tiled_partition<32>(cg::this_thread_block()); + auto warp_group_g = cg::labeled_partition(warp, gid); + if (v_means != nullptr) { + warpSum(v_mean, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_means += gid * 3; + GSPLAT_PRAGMA_UNROLL + for (uint32_t i = 0; i < 3; i++) { + gpuAtomicAdd(v_means + i, v_mean[i]); + } + } + } + + // Directly output gradients w.r.t. 
the quaternion and scale + warpSum(v_quat, warp_group_g); + warpSum(v_scale, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_quats += gid * 4; + v_scales += gid * 3; + gpuAtomicAdd(v_quats, v_quat[0]); + gpuAtomicAdd(v_quats + 1, v_quat[1]); + gpuAtomicAdd(v_quats + 2, v_quat[2]); + gpuAtomicAdd(v_quats + 3, v_quat[3]); + gpuAtomicAdd(v_scales, v_scale[0]); + gpuAtomicAdd(v_scales + 1, v_scale[1]); + } +} + +std::tuple +fully_fused_projection_bwd_2dgs_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 2] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + // fwd outputs + const torch::Tensor &radii, // [C, N] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] + // grad outputs + const torch::Tensor &v_means2d, // [C, N, 2] + const torch::Tensor &v_depths, // [C, N] + const torch::Tensor &v_normals, // [C, N, 3] + const torch::Tensor &v_ray_transforms, // [C, N, 3, 3] + const bool viewmats_requires_grad +) { + GSPLAT_DEVICE_GUARD(means); + GSPLAT_CHECK_INPUT(means); + GSPLAT_CHECK_INPUT(quats); + GSPLAT_CHECK_INPUT(scales); + GSPLAT_CHECK_INPUT(viewmats); + GSPLAT_CHECK_INPUT(Ks); + GSPLAT_CHECK_INPUT(radii); + GSPLAT_CHECK_INPUT(ray_transforms); + GSPLAT_CHECK_INPUT(v_means2d); + GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); + GSPLAT_CHECK_INPUT(v_ray_transforms); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor v_means = torch::zeros_like(means); + torch::Tensor v_quats = torch::zeros_like(quats); + torch::Tensor v_scales = torch::zeros_like(scales); + torch::Tensor v_viewmats; + if (viewmats_requires_grad) { + v_viewmats = torch::zeros_like(viewmats); + } + if (C && N) { + fully_fused_projection_bwd_2dgs_kernel + <<<(C * N + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS, + GSPLAT_N_THREADS, + 0, + stream>>>( + C, + N, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + radii.data_ptr(), + ray_transforms.data_ptr(), + v_means2d.data_ptr(), + v_depths.data_ptr(), + v_normals.data_ptr(), + v_ray_transforms.data_ptr(), + v_means.data_ptr(), + v_quats.data_ptr(), + v_scales.data_ptr(), + viewmats_requires_grad ? 
v_viewmats.data_ptr() : nullptr + ); + } + return std::make_tuple(v_means, v_quats, v_scales, v_viewmats); +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu new file mode 100644 index 000000000..08726d0f5 --- /dev/null +++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu @@ -0,0 +1,210 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Projection of Gaussians (Single Batch) Forward Pass 2DGS + ****************************************************************************/ + +template +__global__ void fully_fused_projection_fwd_2dgs_kernel( + const uint32_t C, + const uint32_t N, + const T *__restrict__ means, // [N, 3] + const T *__restrict__ quats, // [N, 4] + const T *__restrict__ scales, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, + const int32_t image_height, + const T near_plane, + const T far_plane, + const T radius_clip, + // outputs + int32_t *__restrict__ radii, // [C, N] + T *__restrict__ means2d, // [C, N, 2] + T *__restrict__ depths, // [C, N] + T *__restrict__ ray_transforms, // [C, N, 3, 3] + T *__restrict__ normals // [C, N, 3] +) { + // parallelize over C * N. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= C * N) { + return; + } + const uint32_t cid = idx / N; // camera id + const uint32_t gid = idx % N; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + // glm is column-major but input is row-major + mat3 R = mat3( + viewmats[0], + viewmats[4], + viewmats[8], // 1st column + viewmats[1], + viewmats[5], + viewmats[9], // 2nd column + viewmats[2], + viewmats[6], + viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + + // transform Gaussian center to camera space + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + if (mean_c.z < near_plane || mean_c.z > far_plane) { + radii[idx] = 0; + return; + } + + // build ray transformation matrix and transform from world space to camera + // space + quats += gid * 4; + scales += gid * 3; + + mat3 RS_camera = + R * quat_to_rotmat(glm::make_vec4(quats)) * + mat3(scales[0], 0.0, 0.0, 0.0, scales[1], 0.0, 0.0, 0.0, 1.0); + + mat3 WH = mat3(RS_camera[0], RS_camera[1], mean_c); + + mat3 world_2_pix = + mat3(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0); + mat3 M = glm::transpose(WH) * world_2_pix; + + // compute AABB + const vec3 M0 = vec3(M[0][0], M[0][1], M[0][2]); + const vec3 M1 = vec3(M[1][0], M[1][1], M[1][2]); + const vec3 M2 = vec3(M[2][0], M[2][1], M[2][2]); + + const vec3 temp_point = vec3(1.0f, 1.0f, -1.0f); + const T distance = sum(temp_point * M2 * M2); + + if (distance == 0.0f) + return; + + const vec3 f = (1 / distance) * temp_point; + const vec2 mean2d = vec2(sum(f * M0 * M2), sum(f * M1 * M2)); + + const vec2 temp = {sum(f * M0 * M0), sum(f * M1 * M1)}; + const vec2 half_extend = mean2d * mean2d - temp; + const T radius = + ceil(3.f * sqrt(max(1e-4, max(half_extend.x, half_extend.y)))); + + if (radius <= radius_clip) { + radii[idx] = 0; + return; + } + + // mask out gaussians outside the image region + if (mean2d.x + radius <= 0 || mean2d.x 
- radius >= image_width || + mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + radii[idx] = 0; + return; + } + + // normals dual visible + vec3 normal = RS_camera[2]; + T multipler = glm::dot(-normal, mean_c) > 0 ? 1 : -1; + normal *= multipler; + + // write to outputs + radii[idx] = (int32_t)radius; + means2d[idx * 2] = mean2d.x; + means2d[idx * 2 + 1] = mean2d.y; + depths[idx] = mean_c.z; + ray_transforms[idx * 9] = M0.x; + ray_transforms[idx * 9 + 1] = M0.y; + ray_transforms[idx * 9 + 2] = M0.z; + ray_transforms[idx * 9 + 3] = M1.x; + ray_transforms[idx * 9 + 4] = M1.y; + ray_transforms[idx * 9 + 5] = M1.z; + ray_transforms[idx * 9 + 6] = M2.x; + ray_transforms[idx * 9 + 7] = M2.y; + ray_transforms[idx * 9 + 8] = M2.z; + normals[idx * 3] = normal.x; + normals[idx * 3 + 1] = normal.y; + normals[idx * 3 + 2] = normal.z; +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_fwd_2dgs_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float eps2d, + const float near_plane, + const float far_plane, + const float radius_clip +) { + GSPLAT_DEVICE_GUARD(means); + GSPLAT_CHECK_INPUT(means); + GSPLAT_CHECK_INPUT(quats); + GSPLAT_CHECK_INPUT(scales); + GSPLAT_CHECK_INPUT(viewmats); + GSPLAT_CHECK_INPUT(Ks); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor radii = + torch::empty({C, N}, means.options().dtype(torch::kInt32)); + torch::Tensor means2d = torch::empty({C, N, 2}, means.options()); + torch::Tensor depths = torch::empty({C, N}, means.options()); + torch::Tensor ray_transforms = torch::empty({C, N, 3, 3}, means.options()); + torch::Tensor normals = torch::empty({C, N, 3}, means.options()); + + if (C && N) { + fully_fused_projection_fwd_2dgs_kernel + <<<(C * N + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS, + GSPLAT_N_THREADS, + 0, + stream>>>( + C, + N, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + near_plane, + far_plane, + radius_clip, + radii.data_ptr(), + means2d.data_ptr(), + depths.data_ptr(), + ray_transforms.data_ptr(), + normals.data_ptr() + ); + } + return std::make_tuple(radii, means2d, depths, ray_transforms, normals); +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu new file mode 100644 index 000000000..564eb70bd --- /dev/null +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu @@ -0,0 +1,270 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "utils.cuh" + +#include +#include +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Projection of Gaussians (Batched) Backward Pass 2DGS + ****************************************************************************/ + +template +__global__ void fully_fused_projection_packed_bwd_2dgs_kernel( + // fwd inputs + const uint32_t C, + const uint32_t N, + const uint32_t nnz, + const T *__restrict__ means, // [N, 3] + const T 
*__restrict__ quats, // [N, 4] + const T *__restrict__ scales, // [N, 3] + const T *__restrict__ viewmats, // [C, 4, 4] + const T *__restrict__ Ks, // [C, 3, 3] + const int32_t image_width, + const int32_t image_height, + // fwd outputs + const int64_t *__restrict__ camera_ids, // [nnz] + const int64_t *__restrict__ gaussian_ids, // [nnz] + const T *__restrict__ ray_transforms, // [nnz, 3] + // grad outputs + const T *__restrict__ v_means2d, // [nnz, 2] + const T *__restrict__ v_depths, // [nnz] + const T *__restrict__ v_normals, // [nnz, 3] + const bool sparse_grad, // whether the outputs are in COO format [nnz, ...] + // grad inputs + T *__restrict__ v_ray_transforms, + T *__restrict__ v_means, // [N, 3] or [nnz, 3] + T *__restrict__ v_quats, // [N, 4] or [nnz, 4] Optional + T *__restrict__ v_scales, // [N, 3] or [nnz, 3] Optional + T *__restrict__ v_viewmats // [C, 4, 4] Optional +) { + // parallelize over nnz. + uint32_t idx = cg::this_grid().thread_rank(); + if (idx >= nnz) { + return; + } + const int64_t cid = camera_ids[idx]; // camera id + const int64_t gid = gaussian_ids[idx]; // gaussian id + + // shift pointers to the current camera and gaussian + means += gid * 3; + viewmats += cid * 16; + Ks += cid * 9; + + ray_transforms += idx * 9; + + v_means2d += idx * 2; + v_normals += idx * 3; + v_depths += idx; + v_ray_transforms += idx * 9; + + // transform Gaussian to camera space + mat3 R = mat3( + viewmats[0], + viewmats[4], + viewmats[8], // 1st column + viewmats[1], + viewmats[5], + viewmats[9], // 2nd column + viewmats[2], + viewmats[6], + viewmats[10] // 3rd column + ); + vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]); + vec3 mean_c; + pos_world_to_cam(R, t, glm::make_vec3(means), mean_c); + + vec4 quat = glm::make_vec4(quats + gid * 4); + vec2 scale = glm::make_vec2(scales + gid * 3); + mat3 P = mat3(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0); + + mat3 _v_ray_transforms = mat3( + v_ray_transforms[0], + v_ray_transforms[1], + v_ray_transforms[2], + v_ray_transforms[3], + v_ray_transforms[4], + v_ray_transforms[5], + v_ray_transforms[6], + v_ray_transforms[7], + v_ray_transforms[8] + ); + + _v_ray_transforms[2][2] += v_depths[0]; + + vec3 v_normal = glm::make_vec3(v_normals); + + vec3 v_mean(0.f); + vec2 v_scale(0.f); + vec4 v_quat(0.f); + compute_ray_transforms_aabb_vjp( + ray_transforms, + v_means2d, + v_normal, + R, + P, + t, + mean_c, + quat, + scale, + _v_ray_transforms, + v_quat, + v_scale, + v_mean + ); + + auto warp = cg::tiled_partition<32>(cg::this_thread_block()); + if (sparse_grad) { + // write out results with sparse layout + if (v_means != nullptr) { + v_means += idx * 3; + GSPLAT_PRAGMA_UNROLL + for (uint32_t i = 0; i < 3; i++) { + v_means[i] = v_mean[i]; + } + } + v_quats += idx * 4; + v_scales += idx * 3; + v_quats[0] = v_quat[0]; + v_quats[1] = v_quat[1]; + v_quats[2] = v_quat[2]; + v_quats[3] = v_quat[3]; + v_scales[0] = v_scale[0]; + v_scales[1] = v_scale[1]; + } else { + // write out results with dense layout + // #if __CUDA_ARCH__ >= 700 + // write out results with warp-level reduction + auto warp_group_g = cg::labeled_partition(warp, gid); + if (v_means != nullptr) { + warpSum(v_mean, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_means += gid * 3; + GSPLAT_PRAGMA_UNROLL + for (uint32_t i = 0; i < 3; i++) { + gpuAtomicAdd(v_means + i, v_mean[i]); + } + } + } + // Directly output gradients w.r.t. 
the quaternion and scale + warpSum(v_quat, warp_group_g); + warpSum(v_scale, warp_group_g); + if (warp_group_g.thread_rank() == 0) { + v_quats += gid * 4; + v_scales += gid * 3; + gpuAtomicAdd(v_quats, v_quat[0]); + gpuAtomicAdd(v_quats + 1, v_quat[1]); + gpuAtomicAdd(v_quats + 2, v_quat[2]); + gpuAtomicAdd(v_quats + 3, v_quat[3]); + gpuAtomicAdd(v_scales, v_scale[0]); + gpuAtomicAdd(v_scales + 1, v_scale[1]); + } + } +} + +std::tuple +fully_fused_projection_packed_bwd_2dgs_tensor( + // fwd inputs + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 4] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + // fwd outputs + const torch::Tensor &camera_ids, // [nnz] + const torch::Tensor &gaussian_ids, // [nnz] + const torch::Tensor &ray_transforms, // [nnz, 3, 3] + // grad outputs + const torch::Tensor &v_means2d, // [nnz, 2] + const torch::Tensor &v_depths, // [nnz] + const torch::Tensor &v_ray_transforms, // [nnz, 3, 3] + const torch::Tensor &v_normals, // [nnz, 3] + const bool viewmats_requires_grad, + const bool sparse_grad +) { + + GSPLAT_DEVICE_GUARD(means); + GSPLAT_CHECK_INPUT(means); + GSPLAT_CHECK_INPUT(quats); + GSPLAT_CHECK_INPUT(scales); + GSPLAT_CHECK_INPUT(viewmats); + GSPLAT_CHECK_INPUT(Ks); + GSPLAT_CHECK_INPUT(camera_ids); + GSPLAT_CHECK_INPUT(gaussian_ids); + GSPLAT_CHECK_INPUT(ray_transforms); + GSPLAT_CHECK_INPUT(v_means2d); + GSPLAT_CHECK_INPUT(v_depths); + GSPLAT_CHECK_INPUT(v_normals); + GSPLAT_CHECK_INPUT(v_ray_transforms); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + uint32_t nnz = camera_ids.size(0); + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + torch::Tensor v_means, v_quats, v_scales, v_viewmats; + if (sparse_grad) { + v_means = torch::zeros({nnz, 3}, means.options()); + + v_quats = torch::zeros({nnz, 4}, quats.options()); + v_scales = torch::zeros({nnz, 3}, scales.options()); + + if (viewmats_requires_grad) { + v_viewmats = torch::zeros({C, 4, 4}, viewmats.options()); + } + + } else { + v_means = torch::zeros_like(means); + + v_quats = torch::zeros_like(quats); + v_scales = torch::zeros_like(scales); + + if (viewmats_requires_grad) { + v_viewmats = torch::zeros_like(viewmats); + } + } + if (nnz) { + + fully_fused_projection_packed_bwd_2dgs_kernel + <<<(nnz + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS, + GSPLAT_N_THREADS, + 0, + stream>>>( + C, + N, + nnz, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + camera_ids.data_ptr(), + gaussian_ids.data_ptr(), + ray_transforms.data_ptr(), + v_means2d.data_ptr(), + v_depths.data_ptr(), + v_normals.data_ptr(), + sparse_grad, + v_ray_transforms.data_ptr(), + v_means.data_ptr(), + v_quats.data_ptr(), + v_scales.data_ptr(), + viewmats_requires_grad ? 
v_viewmats.data_ptr() : nullptr
+        );
+    }
+    return std::make_tuple(v_means, v_quats, v_scales, v_viewmats);
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu
new file mode 100644
index 000000000..34ac796fb
--- /dev/null
+++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu
@@ -0,0 +1,328 @@
+
+#include "bindings.h"
+#include "helpers.cuh"
+#include "utils.cuh"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Projection of Gaussians (Batched) Forward Pass 2DGS
+ ****************************************************************************/
+
+template
+__global__ void fully_fused_projection_packed_fwd_2dgs_kernel(
+    const uint32_t C,
+    const uint32_t N,
+    const T *__restrict__ means,    // [N, 3]
+    const T *__restrict__ quats,    // [N, 4]
+    const T *__restrict__ scales,   // [N, 3]
+    const T *__restrict__ viewmats, // [C, 4, 4]
+    const T *__restrict__ Ks,       // [C, 3, 3]
+    const int32_t image_width,
+    const int32_t image_height,
+    const T near_plane,
+    const T far_plane,
+    const T radius_clip,
+    const int32_t
+        *__restrict__ block_accum,    // [C * blocks_per_row] packing helper
+    int32_t *__restrict__ block_cnts, // [C * blocks_per_row] packing helper
+    // outputs
+    int32_t *__restrict__ indptr,       // [C + 1]
+    int64_t *__restrict__ camera_ids,   // [nnz]
+    int64_t *__restrict__ gaussian_ids, // [nnz]
+    int32_t *__restrict__ radii,        // [nnz]
+    T *__restrict__ means2d,            // [nnz, 2]
+    T *__restrict__ depths,             // [nnz]
+    T *__restrict__ ray_transforms,     // [nnz, 3, 3]
+    T *__restrict__ normals             // [nnz, 3]
+) {
+    int32_t blocks_per_row = gridDim.x;
+
+    int32_t row_idx = blockIdx.y; // cid
+    int32_t block_col_idx = blockIdx.x;
+    int32_t block_idx = row_idx * blocks_per_row + block_col_idx;
+
+    int32_t col_idx = block_col_idx * blockDim.x + threadIdx.x; // gid
+
+    bool valid = (row_idx < C) && (col_idx < N);
+
+    // check if points are within the camera near and far plane
+    vec3 mean_c;
+    mat3 R;
+    if (valid) {
+        // shift pointers to the current camera and gaussian
+        means += col_idx * 3;
+        viewmats += row_idx * 16;
+        Ks += row_idx * 9; // intrinsics of the current camera
+
+        // glm is column-major but input is row-major
+        R = mat3(
+            viewmats[0],
+            viewmats[4],
+            viewmats[8], // 1st column
+            viewmats[1],
+            viewmats[5],
+            viewmats[9], // 2nd column
+            viewmats[2],
+            viewmats[6],
+            viewmats[10] // 3rd column
+        );
+        vec3 t = vec3(viewmats[3], viewmats[7], viewmats[11]);
+
+        // transform Gaussian center to camera space
+        pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+        if (mean_c.z < near_plane || mean_c.z > far_plane) {
+            valid = false;
+        }
+    }
+
+    vec2 mean2d;
+    mat3 M;
+    T radius;
+    vec3 normal;
+    if (valid) {
+        // build ray transformation matrix and transform from world space to
+        // camera space
+        quats += col_idx * 4;
+        scales += col_idx * 3;
+
+        mat3 RS_camera =
+            R * quat_to_rotmat(glm::make_vec4(quats)) *
+            mat3(scales[0], 0.0, 0.0, 0.0, scales[1], 0.0, 0.0, 0.0, 1.0);
+        mat3 WH = mat3(RS_camera[0], RS_camera[1], mean_c);
+
+        mat3 world_2_pix =
+            mat3(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+        M = glm::transpose(WH) * world_2_pix;
+
+        // compute AABB
+        const vec3 M0 = vec3(M[0][0], M[0][1], M[0][2]);
+        const vec3 M1 = vec3(M[1][0], M[1][1], M[1][2]);
+        const vec3 M2 = vec3(M[2][0], M[2][1], M[2][2]);
+
+        const vec3 temp_point = vec3(1.0f, 1.0f, -1.0f);
+        const
T distance = sum(temp_point * M2 * M2); + + if (distance == 0.0f) + valid = false; + + const vec3 f = (1 / distance) * temp_point; + mean2d = vec2(sum(f * M0 * M2), sum(f * M1 * M2)); + + const vec2 temp = {sum(f * M0 * M0), sum(f * M1 * M1)}; + const vec2 half_extend = mean2d * mean2d - temp; + radius = ceil(3.f * sqrt(max(1e-4, max(half_extend.x, half_extend.y)))); + + if (radius <= radius_clip) { + valid = false; + } + + // mask out gaussians outside the image region + if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width || + mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) { + valid = false; + } + + // normal dual visible + normal = RS_camera[2]; + T multipler = glm::dot(-normal, mean_c) > 0 ? 1 : -1; + normal *= multipler; + } + + int32_t thread_data = static_cast(valid); + if (block_cnts != nullptr) { + // First pass: compute the block-wide sum + int32_t aggregate; + if (__syncthreads_or(thread_data)) { + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + aggregate = BlockReduce(temp_storage).Sum(thread_data); + } else { + aggregate = 0; + } + if (threadIdx.x == 0) { + block_cnts[block_idx] = aggregate; + } + } else { + // Second pass: write out the indices of the non zero elements + if (__syncthreads_or(thread_data)) { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data); + } + if (valid) { + if (block_idx > 0) { + int32_t offset = block_accum[block_idx - 1]; + thread_data += offset; + } + // write to outputs + camera_ids[thread_data] = row_idx; // cid + gaussian_ids[thread_data] = col_idx; // gid + radii[thread_data] = (int32_t)radius; + means2d[thread_data * 2] = mean2d.x; + means2d[thread_data * 2 + 1] = mean2d.y; + depths[thread_data] = mean_c.z; + ray_transforms[thread_data * 9] = M[0][0]; + ray_transforms[thread_data * 9 + 1] = M[0][1]; + ray_transforms[thread_data * 9 + 2] = M[0][2]; + ray_transforms[thread_data * 9 + 3] = M[1][0]; + ray_transforms[thread_data * 9 + 4] = M[1][1]; + ray_transforms[thread_data * 9 + 5] = M[1][2]; + ray_transforms[thread_data * 9 + 6] = M[2][0]; + ray_transforms[thread_data * 9 + 7] = M[2][1]; + ray_transforms[thread_data * 9 + 8] = M[2][2]; + normals[thread_data * 3] = normal.x; + normals[thread_data * 3 + 1] = normal.y; + normals[thread_data * 3 + 2] = normal.z; + } + // lane 0 of the first block in each row writes the indptr + if (threadIdx.x == 0 && block_col_idx == 0) { + if (row_idx == 0) { + indptr[0] = 0; + indptr[C] = block_accum[C * blocks_per_row - 1]; + } else { + indptr[row_idx] = block_accum[block_idx - 1]; + } + } + } +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +fully_fused_projection_packed_fwd_2dgs_tensor( + const torch::Tensor &means, // [N, 3] + const torch::Tensor &quats, // [N, 3] + const torch::Tensor &scales, // [N, 3] + const torch::Tensor &viewmats, // [C, 4, 4] + const torch::Tensor &Ks, // [C, 3, 3] + const uint32_t image_width, + const uint32_t image_height, + const float near_plane, + const float far_plane, + const float radius_clip +) { + GSPLAT_DEVICE_GUARD(means); + GSPLAT_CHECK_INPUT(means); + GSPLAT_CHECK_INPUT(quats); + GSPLAT_CHECK_INPUT(scales); + GSPLAT_CHECK_INPUT(viewmats); + GSPLAT_CHECK_INPUT(Ks); + + uint32_t N = means.size(0); // number of gaussians + uint32_t C = viewmats.size(0); // number of cameras + 
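+    // Two-pass packing: the first launch only counts the valid
+    // (camera, gaussian) pairs per block (block_cnts); a cumsum over the
+    // counts gives each block its write offset (block_accum); the second
+    // launch scatters the surviving splats into the packed [nnz] outputs.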
at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + auto opt = means.options().dtype(torch::kInt32); + + uint32_t nrows = C; + uint32_t ncols = N; + uint32_t blocks_per_row = (ncols + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS; + + dim3 threads = {GSPLAT_N_THREADS, 1, 1}; + // limit on the number of blocks: [2**31 - 1, 65535, 65535] + dim3 blocks = {blocks_per_row, nrows, 1}; + + // first pass + int32_t nnz; + torch::Tensor block_accum; + if (C && N) { + torch::Tensor block_cnts = torch::empty({nrows * blocks_per_row}, opt); + fully_fused_projection_packed_fwd_2dgs_kernel + <<>>( + C, + N, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + near_plane, + far_plane, + radius_clip, + nullptr, + block_cnts.data_ptr(), + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr + ); + block_accum = torch::cumsum(block_cnts, 0, torch::kInt32); + nnz = block_accum[-1].item(); + } else { + nnz = 0; + } + + // second pass + torch::Tensor indptr = torch::empty({C + 1}, opt); + torch::Tensor camera_ids = torch::empty({nnz}, opt.dtype(torch::kInt64)); + torch::Tensor gaussian_ids = torch::empty({nnz}, opt.dtype(torch::kInt64)); + torch::Tensor radii = + torch::empty({nnz}, means.options().dtype(torch::kInt32)); + torch::Tensor means2d = torch::empty({nnz, 2}, means.options()); + torch::Tensor depths = torch::empty({nnz}, means.options()); + torch::Tensor ray_transforms = torch::empty({nnz, 3, 3}, means.options()); + torch::Tensor normals = torch::empty({nnz, 3}, means.options()); + + if (nnz) { + fully_fused_projection_packed_fwd_2dgs_kernel + <<>>( + C, + N, + means.data_ptr(), + quats.data_ptr(), + scales.data_ptr(), + viewmats.data_ptr(), + Ks.data_ptr(), + image_width, + image_height, + near_plane, + far_plane, + radius_clip, + block_accum.data_ptr(), + nullptr, + indptr.data_ptr(), + camera_ids.data_ptr(), + gaussian_ids.data_ptr(), + radii.data_ptr(), + means2d.data_ptr(), + depths.data_ptr(), + ray_transforms.data_ptr(), + normals.data_ptr() + ); + } else { + indptr.fill_(0); + } + + return std::make_tuple( + indptr, + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + ray_transforms, + normals + ); +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu index 27a337a34..12f21c611 100644 --- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu +++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu @@ -1,4 +1,3 @@ - #include "bindings.h" #include "helpers.cuh" #include "utils.cuh" diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh index 0e4995bad..93ff6bd45 100644 --- a/gsplat/cuda/csrc/helpers.cuh +++ b/gsplat/cuda/csrc/helpers.cuh @@ -73,6 +73,10 @@ inline __device__ void warpMax(ScalarT &val, WarpT &warp) { val = cg::reduce(warp, val, cg::greater()); } +template __forceinline__ __device__ T sum(vec3 a) { + return a.x + a.y + a.z; +} + } // namespace gsplat -#endif // GSPLAT_CUDA_HELPERS_H \ No newline at end of file +#endif // GSPLAT_CUDA_HELPERS_H diff --git a/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu new file mode 100644 index 000000000..1432a1666 --- /dev/null +++ b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu @@ -0,0 +1,339 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include "utils.cuh" +#include 
+#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Rasterization to Indices in Range 2DGS + ****************************************************************************/ + +template +__global__ void rasterize_to_indices_in_range_kernel( + const uint32_t range_start, + const uint32_t range_end, + const uint32_t C, + const uint32_t N, + const uint32_t n_isects, + const vec2 *__restrict__ means2d, // [C, N, 2] + const T *__restrict__ ray_transforms, // [C, N, 3, 3] + const T *__restrict__ opacities, // [C, N] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + const uint32_t tile_width, + const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + const T *__restrict__ transmittances, // [C, image_height, image_width] + const int32_t *__restrict__ chunk_starts, // [C, image_height, image_width] + int32_t *__restrict__ chunk_cnts, // [C, image_height, image_width] + int64_t *__restrict__ gaussian_ids, // [n_elems] + int64_t *__restrict__ pixel_ids // [n_elems] +) { + // each thread draws one pixel, but also timeshares caching gaussians in a + // shared tile + + auto block = cg::this_thread_block(); + uint32_t camera_id = block.group_index().x; + uint32_t tile_id = + block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + // move pointers to the current camera + tile_offsets += camera_id * tile_height * tile_width; + transmittances += camera_id * image_height * image_width; + + T px = (T)j + 0.5f; + T py = (T)i + 0.5f; + int32_t pix_id = i * image_width + j; + + // return if out of bounds + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + bool done = !inside; + + bool first_pass = chunk_starts == nullptr; + int32_t base; + if (!first_pass && inside) { + chunk_starts += camera_id * image_height * image_width; + base = chunk_starts[pix_id]; + } + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t isect_range_start = tile_offsets[tile_id]; + int32_t isect_range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? n_isects + : tile_offsets[tile_id + 1]; + const uint32_t block_size = block.size(); + uint32_t num_batches = + (isect_range_end - isect_range_start + block_size - 1) / block_size; + + if (range_start >= num_batches) { + // this entire tile has been processed in the previous iterations + // so we don't need to do anything. + return; + } + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast *>(&id_batch[block_size]); // [block_size] + vec3 *u_Ms_batch = + reinterpret_cast *>(&xy_opacity_batch[block_size] + ); // [block_size] + vec3 *v_Ms_batch = + reinterpret_cast *>(&u_Ms_batch[block_size] + ); // [block_size] + vec3 *w_Ms_batch = + reinterpret_cast *>(&v_Ms_batch[block_size] + ); // [block_size] + + // current visibility left to render + // transmittance is gonna be used in the backward pass which requires a high + // numerical precision so we (should) use double for it. 
However, double makes the
+    // backward pass ~1.5x slower, so we stick with float for now.
+    T trans, next_trans;
+    if (inside) {
+        trans = transmittances[pix_id];
+        next_trans = trans;
+    }
+
+    // collect and process batches of gaussians
+    // each thread loads one gaussian at a time before rasterizing its
+    // designated pixel
+    uint32_t tr = block.thread_rank();
+
+    int32_t cnt = 0;
+    for (uint32_t b = range_start; b < min(range_end, num_batches); ++b) {
+        // resync all threads before beginning the next batch;
+        // end early if the entire tile is done
+        if (__syncthreads_count(done) >= block_size) {
+            break;
+        }
+
+        // each thread fetches one gaussian from front to back
+        // index of gaussian to load
+        uint32_t batch_start = isect_range_start + block_size * b;
+        uint32_t idx = batch_start + tr;
+        if (idx < isect_range_end) {
+            int32_t g = flatten_ids[idx];
+            id_batch[tr] = g;
+            const vec2<T> xy = means2d[g];
+            const T opac = opacities[g];
+            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+            u_Ms_batch[tr] = {
+                ray_transforms[g * 9],
+                ray_transforms[g * 9 + 1],
+                ray_transforms[g * 9 + 2]
+            };
+            v_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 3],
+                ray_transforms[g * 9 + 4],
+                ray_transforms[g * 9 + 5]
+            };
+            w_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 6],
+                ray_transforms[g * 9 + 7],
+                ray_transforms[g * 9 + 8]
+            };
+        }
+
+        // wait for other threads to collect the gaussians in batch
+        block.sync();
+
+        // process gaussians in the current batch for this pixel
+        uint32_t batch_size = min(block_size, isect_range_end - batch_start);
+        for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
+            const vec3<T> u_M = u_Ms_batch[t];
+            const vec3<T> v_M = v_Ms_batch[t];
+            const vec3<T> w_M = w_Ms_batch[t];
+            const vec3<T> xy_opac = xy_opacity_batch[t];
+            const T opac = xy_opac.z;
+
+            const vec3<T> h_u = px * w_M - u_M;
+            const vec3<T> h_v = py * w_M - v_M;
+
+            const vec3<T> ray_cross = glm::cross(h_u, h_v);
+
+            if (ray_cross.z == 0.0)
+                continue;
+
+            const vec2<T> s = {
+                ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z
+            };
+            const T gauss_weight_3d = s.x * s.x + s.y * s.y;
+
+            // low-pass filter: cap the weight by the 2D screen-space distance
+            const vec2<T> d = {xy_opac.x - px, xy_opac.y - py};
+            const T gauss_weight_2d =
+                FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
+            const T gauss_weight = min(gauss_weight_3d, gauss_weight_2d);
+
+            const T sigma = 0.5f * gauss_weight;
+            T alpha = min(0.999f, opac * __expf(-sigma));
+
+            if (sigma < 0.f || alpha < 1.f / 255.f) {
+                continue;
+            }
+
+            next_trans = trans * (1.0f - alpha);
+            if (next_trans <= 1e-4) { // this pixel is done: exclusive
+                done = true;
+                break;
+            }
+
+            if (first_pass) {
+                // in the first pass we only count the number of gaussians
+                // that contribute to each pixel
+                cnt += 1;
+            } else {
+                // in the second pass we write out the gaussian ids and pixel ids
+                int32_t g = id_batch[t]; // flatten index in [C * N]
+                gaussian_ids[base + cnt] = g % N;
+                pixel_ids[base + cnt] =
+                    pix_id + camera_id * image_height * image_width;
+                cnt += 1;
+            }
+
+            trans = next_trans;
+        }
+    }
+
+    if (inside && first_pass) {
+        chunk_cnts += camera_id * image_height * image_width;
+        chunk_cnts[pix_id] = cnt;
+    }
+}
+
+std::tuple<torch::Tensor, torch::Tensor>
+rasterize_to_indices_in_range_2dgs_tensor(
+    const uint32_t range_start,
+    const uint32_t range_end, // iteration steps
+    const torch::Tensor transmittances, // [C, image_height, image_width]
+    // Gaussian parameters
+    const torch::Tensor &means2d,        // [C, N, 2]
+    const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+    const torch::Tensor &opacities,      // [C, N]
+    // image size
+    const uint32_t image_width,
+    const uint32_t image_height,
+    const
uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +) { + GSPLAT_DEVICE_GUARD(means2d); + GSPLAT_CHECK_INPUT(means2d); + GSPLAT_CHECK_INPUT(ray_transforms); + GSPLAT_CHECK_INPUT(opacities); + GSPLAT_CHECK_INPUT(tile_offsets); + GSPLAT_CHECK_INPUT(flatten_ids); + + uint32_t C = means2d.size(0); // number of cameras + uint32_t N = means2d.size(1); // number of gaussians + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + uint32_t n_isects = flatten_ids.size(0); + + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. + dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + const uint32_t shared_mem = + tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + + sizeof(vec3) + sizeof(vec3)); + if (cudaFuncSetAttribute( + rasterize_to_indices_in_range_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem + ) != cudaSuccess) { + AT_ERROR( + "Failed to set maximum shared memory size (requested ", + shared_mem, + " bytes), try lowering tile_size." + ); + } + + // First pass: count the number of gaussians that contribute to each pixel + int64_t n_elems; + torch::Tensor chunk_starts; + if (n_isects) { + torch::Tensor chunk_cnts = torch::zeros( + {C * image_height * image_width}, + means2d.options().dtype(torch::kInt32) + ); + rasterize_to_indices_in_range_kernel + <<>>( + range_start, + range_end, + C, + N, + n_isects, + reinterpret_cast *>(means2d.data_ptr()), + ray_transforms.data_ptr(), + opacities.data_ptr(), + image_width, + image_height, + tile_size, + tile_width, + tile_height, + tile_offsets.data_ptr(), + flatten_ids.data_ptr(), + transmittances.data_ptr(), + nullptr, + chunk_cnts.data_ptr(), + nullptr, + nullptr + ); + + torch::Tensor cumsum = + torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type()); + n_elems = cumsum[-1].item(); + chunk_starts = cumsum - chunk_cnts; + } else { + n_elems = 0; + } + + // Second pass: allocate memory and write out the gaussian and pixel ids. 
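+    // chunk_starts is the exclusive prefix sum of the per-pixel counts
+    // (cumsum - chunk_cnts above), so every pixel owns a contiguous slice of
+    // the output buffers. Worked example (illustrative): chunk_cnts = [2, 0, 3]
+    // gives cumsum = [2, 2, 5], hence n_elems = 5 and chunk_starts = [0, 2, 2];
+    // pixel 0 fills slots 0-1 and pixel 2 fills slots 2-4.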
+ torch::Tensor gaussian_ids = + torch::empty({n_elems}, means2d.options().dtype(torch::kInt64)); + torch::Tensor pixel_ids = + torch::empty({n_elems}, means2d.options().dtype(torch::kInt64)); + if (n_elems) { + rasterize_to_indices_in_range_kernel + <<>>( + range_start, + range_end, + C, + N, + n_isects, + reinterpret_cast *>(means2d.data_ptr()), + ray_transforms.data_ptr(), + opacities.data_ptr(), + image_width, + image_height, + tile_size, + tile_width, + tile_height, + tile_offsets.data_ptr(), + flatten_ids.data_ptr(), + transmittances.data_ptr(), + chunk_starts.data_ptr(), + nullptr, + gaussian_ids.data_ptr(), + pixel_ids.data_ptr() + ); + } + return std::make_tuple(gaussian_ids, pixel_ids); +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu new file mode 100644 index 000000000..ddf508283 --- /dev/null +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu @@ -0,0 +1,711 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include "utils.cuh" +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + +/**************************************************************************** + * Rasterization to Pixels Backward Pass 2DGS + ****************************************************************************/ +template +__global__ void rasterize_to_pixels_bwd_2dgs_kernel( + const uint32_t C, + const uint32_t N, + const uint32_t n_isects, + const bool packed, + // fwd inputs + const vec2 *__restrict__ means2d, // [C, N, 2] or [nnz, 2] + const S *__restrict__ ray_transforms, // [C, N, 3] or [nnz, 3] + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] + const S *__restrict__ opacities, // [C, N] or [nnz] + const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM] + const bool *__restrict__ masks, // [C, tile_height, tile_width] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + const uint32_t tile_width, + const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + // fwd outputs + const S *__restrict__ render_colors, // [C, image_height, image_width, + // COLOR_DIM] + const S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + const int32_t *__restrict__ last_ids, // [C, image_height, image_width] + const int32_t *__restrict__ median_ids, // [C, image_height, image_width] + // grad outputs + const S *__restrict__ v_render_colors, // [C, image_height, image_width, + // COLOR_DIM] + const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1] + const S *__restrict__ v_render_normals, // [C, image_height, image_width, 3] + const S *__restrict__ v_render_distort, // [C, image_height, image_width, 1] + const S *__restrict__ v_render_median, // [C, image_height, image_width, 1] + // grad inputs + vec2 *__restrict__ v_means2d_abs, // [C, N, 2] or [nnz, 2] + vec2 *__restrict__ v_means2d, // [C, N, 2] or [nnz, 2] + S *__restrict__ v_ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3] + S *__restrict__ v_colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + S *__restrict__ v_opacities, // [C, N] or [nnz] + S *__restrict__ v_normals, // [C, N, 3] or [nnz, 3] + S *__restrict__ v_densify +) { + auto block = cg::this_thread_block(); + uint32_t camera_id = block.group_index().x; + 
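+    // Launch geometry (set on the host side): blocks = {C, tile_height,
+    // tile_width} and threads = {tile_size, tile_size, 1}, so each block owns
+    // one tile of one camera and (i, j) below is the pixel this thread
+    // rasterizes.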
uint32_t tile_id = + block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width; + median_ids += camera_id * image_height * image_width; + v_render_colors += camera_id * image_height * image_width * COLOR_DIM; + v_render_alphas += camera_id * image_height * image_width; + v_render_normals += camera_id * image_height * image_width * 3; + v_render_median += camera_id * image_height * image_width; + if (backgrounds != nullptr) { + backgrounds += camera_id * COLOR_DIM; + } + if (masks != nullptr) { + masks += camera_id * tile_height * tile_width; + } + if (v_render_distort != nullptr) { + v_render_distort += camera_id * image_height * image_width; + } + + // when the mask is provided, do nothing and return if + // this tile is labeled as False + if (masks != nullptr && !masks[tile_id]) { + return; + } + + const S px = (S)j + 0.5f; + const S py = (S)i + 0.5f; + // clamp this value to the last pixel + const int32_t pix_id = + min(i * image_width + j, image_width * image_height - 1); + + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? n_isects + : tile_offsets[tile_id + 1]; + const uint32_t block_size = block.size(); + const uint32_t num_batches = + (range_end - range_start + block_size - 1) / block_size; + + extern __shared__ int s[]; + int32_t *id_batch = (int32_t *)s; // [block_size] + vec3 *xy_opacity_batch = + reinterpret_cast *>(&id_batch[block_size]); // [block_size] + vec3 *u_Ms_batch = + reinterpret_cast *>(&xy_opacity_batch[block_size] + ); // [block_size] + vec3 *v_Ms_batch = + reinterpret_cast *>(&u_Ms_batch[block_size] + ); // [block_size] + vec3 *w_Ms_batch = + reinterpret_cast *>(&v_Ms_batch[block_size] + ); // [block_size] + S *rgbs_batch = (S *)&w_Ms_batch[block_size]; // [block_size * COLOR_DIM] + S *normals_batch = &rgbs_batch[block_size * COLOR_DIM]; // [block_size * 3] + + // this is the T AFTER the last gaussian in this pixel + S T_final = 1.0f - render_alphas[pix_id]; + S T = T_final; + // the contribution from gaussians behind the current one + S buffer[COLOR_DIM] = {0.f}; + S buffer_normals[3] = {0.f}; + // index of last gaussian to contribute to this pixel + const int32_t bin_final = inside ? last_ids[pix_id] : 0; + // index of gaussian that contributes to median depth + const int32_t median_idx = inside ? 
median_ids[pix_id] : 0;
+
+    // df/d_out for this pixel
+    S v_render_c[COLOR_DIM];
+    GSPLAT_PRAGMA_UNROLL
+    for (uint32_t k = 0; k < COLOR_DIM; ++k) {
+        v_render_c[k] = v_render_colors[pix_id * COLOR_DIM + k];
+    }
+    const S v_render_a = v_render_alphas[pix_id];
+    S v_render_n[3];
+    GSPLAT_PRAGMA_UNROLL
+    for (uint32_t k = 0; k < 3; ++k) {
+        v_render_n[k] = v_render_normals[pix_id * 3 + k];
+    }
+
+    // prepare for distortion
+    S v_distort = 0.f;
+    S accum_d, accum_w;
+    S accum_d_buffer, accum_w_buffer, distort_buffer;
+    if (v_render_distort != nullptr) {
+        v_distort = v_render_distort[pix_id];
+        // last channel of render_colors is accumulated depth
+        accum_d_buffer = render_colors[pix_id * COLOR_DIM + COLOR_DIM - 1];
+        accum_d = accum_d_buffer;
+        accum_w_buffer = render_alphas[pix_id];
+        accum_w = accum_w_buffer;
+        distort_buffer = 0.f;
+    }
+
+    // median depth gradients
+    S v_median = v_render_median[pix_id];
+
+    // collect and process batches of gaussians
+    // each thread loads one gaussian at a time before rasterizing
+    const uint32_t tr = block.thread_rank();
+    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
+    const int32_t warp_bin_final =
+        cg::reduce(warp, bin_final, cg::greater<int>());
+    for (uint32_t b = 0; b < num_batches; ++b) {
+        // resync all threads before writing the next batch of shared mem
+        block.sync();
+
+        // each thread fetches one gaussian from back to front;
+        // index 0 will be the furthest back in the batch.
+        // batch_end is the index of the last gaussian in the batch;
+        // these values can be negative, so they must be int32, not uint32.
+        const int32_t batch_end = range_end - 1 - block_size * b;
+        const int32_t batch_size = min(block_size, batch_end + 1 - range_start);
+        const int32_t idx = batch_end - tr;
+        if (idx >= range_start) {
+            int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz]
+            id_batch[tr] = g;
+            const vec2<S> xy = means2d[g];
+            const S opac = opacities[g];
+            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+            u_Ms_batch[tr] = {
+                ray_transforms[g * 9],
+                ray_transforms[g * 9 + 1],
+                ray_transforms[g * 9 + 2]
+            };
+            v_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 3],
+                ray_transforms[g * 9 + 4],
+                ray_transforms[g * 9 + 5]
+            };
+            w_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 6],
+                ray_transforms[g * 9 + 7],
+                ray_transforms[g * 9 + 8]
+            };
+            GSPLAT_PRAGMA_UNROLL
+            for (uint32_t k = 0; k < COLOR_DIM; ++k) {
+                rgbs_batch[tr * COLOR_DIM + k] = colors[g * COLOR_DIM + k];
+            }
+            GSPLAT_PRAGMA_UNROLL
+            for (uint32_t k = 0; k < 3; ++k) {
+                normals_batch[tr * 3 + k] = normals[g * 3 + k];
+            }
+        }
+        // wait for other threads to collect the gaussians in batch
+        block.sync();
+        // process gaussians in the current batch for this pixel;
+        // index 0 is the furthest back gaussian in the batch
+        for (uint32_t t = max(0, batch_end - warp_bin_final); t < batch_size;
+             ++t) {
+            bool valid = inside;
+            if (batch_end - t > bin_final) {
+                valid = false;
+            }
+            S alpha;
+            S opac;
+            S vis;
+            S gauss_weight_3d;
+            S gauss_weight_2d;
+            S gauss_weight;
+            vec2<S> s;
+            vec2<S> d;
+            vec3<S> h_u;
+            vec3<S> h_v;
+            vec3<S> ray_cross;
+            vec3<S> w_M;
+            if (valid) {
+                vec3<S> xy_opac = xy_opacity_batch[t];
+                opac = xy_opac.z;
+                const vec3<S> u_M = u_Ms_batch[t];
+                const vec3<S> v_M = v_Ms_batch[t];
+                w_M = w_Ms_batch[t];
+
+                h_u = px * w_M - u_M;
+                h_v = py * w_M - v_M;
+
+                ray_cross = glm::cross(h_u, h_v);
+
+                // no intersection
+                if (ray_cross.z == 0.0)
+                    valid = false;
+                s = {ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z};
+
+                gauss_weight_3d = s.x * s.x + s.y * s.y;
+                d = {xy_opac.x - px, xy_opac.y -
py}; + gauss_weight_2d = FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y); + gauss_weight = min(gauss_weight_3d, gauss_weight_2d); + + const S sigma = 0.5f * gauss_weight; + vis = __expf(-sigma); + alpha = min(0.999f, opac * vis); + if (sigma < 0.f || alpha < 1.f / 255.f) { + valid = false; + } + } + + // if all threads are inactive in this warp, skip this loop + if (!warp.any(valid)) { + continue; + } + S v_rgb_local[COLOR_DIM] = {0.f}; + S v_normal_local[3] = {0.f}; + vec3 v_u_M_local = {0.f, 0.f, 0.f}; + vec3 v_v_M_local = {0.f, 0.f, 0.f}; + vec3 v_w_M_local = {0.f, 0.f, 0.f}; + vec2 v_xy_local = {0.f, 0.f}; + vec2 v_xy_abs_local = {0.f, 0.f}; + S v_opacity_local = 0.f; + // initialize everything to 0, only set if the lane is valid + if (valid) { + // gradient contribution from median depth + if (batch_end - t == median_idx) { + v_rgb_local[COLOR_DIM - 1] += v_median; + } + + // compute the current T for this gaussian + S ra = 1.0f / (1.0f - alpha); + T *= ra; + // update v_rgb for this gaussian + const S fac = alpha * T; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + v_rgb_local[k] += fac * v_render_c[k]; + } + // contribution from this pixel + S v_alpha = 0.f; + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + v_alpha += + (rgbs_batch[t * COLOR_DIM + k] * T - buffer[k] * ra) * + v_render_c[k]; + } + + // update v_normal for this gaussian + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + v_normal_local[k] = fac * v_render_n[k]; + } + + for (uint32_t k = 0; k < 3; ++k) { + v_alpha += (normals_batch[t * 3 + k] * T - + buffer_normals[k] * ra) * + v_render_n[k]; + } + + v_alpha += T_final * ra * v_render_a; + + // contribution from background pixel + if (backgrounds != nullptr) { + S accum = 0.f; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + accum += backgrounds[k] * v_render_c[k]; + } + v_alpha += -T_final * ra * accum; + } + + // contribution from distortion + if (v_render_distort != nullptr) { + // last channel of colors is depth + S depth = rgbs_batch[t * COLOR_DIM + COLOR_DIM - 1]; + S dl_dw = + 2.0f * + (2.0f * (depth * accum_w_buffer - accum_d_buffer) + + (accum_d - depth * accum_w)); + // df / d(alpha) + v_alpha += (dl_dw * T - distort_buffer * ra) * v_distort; + accum_d_buffer -= fac * depth; + accum_w_buffer -= fac; + distort_buffer += dl_dw * fac; + // df / d(depth). 
put it in the last channel of v_rgb + v_rgb_local[COLOR_DIM - 1] += + 2.0f * fac * (2.0f - 2.0f * T - accum_w + fac) * + v_distort; + } + + //====== 2DGS ======// + if (opac * vis <= 0.999f) { + S v_depth = 0.f; + const S v_G = opac * v_alpha; + if (gauss_weight_3d <= gauss_weight_2d) { + const vec2 v_s = { + v_G * -vis * s.x + v_depth * w_M.x, + v_G * -vis * s.y + v_depth * w_M.y + }; + const vec3 v_z_w_M = {s.x, s.y, 1.0}; + const S v_sx_pz = v_s.x / ray_cross.z; + const S v_sy_pz = v_s.y / ray_cross.z; + const vec3 v_ray_cross = { + v_sx_pz, v_sy_pz, -(v_sx_pz * s.x + v_sy_pz * s.y) + }; + const vec3 v_h_u = glm::cross(h_v, v_ray_cross); + const vec3 v_h_v = glm::cross(v_ray_cross, h_u); + + v_u_M_local = {-v_h_u.x, -v_h_u.y, -v_h_u.z}; + v_v_M_local = {-v_h_v.x, -v_h_v.y, -v_h_v.z}; + v_w_M_local = { + px * v_h_u.x + py * v_h_v.x + v_depth * v_z_w_M.x, + px * v_h_u.y + py * v_h_v.y + v_depth * v_z_w_M.y, + px * v_h_u.z + py * v_h_v.z + v_depth * v_z_w_M.z + }; + + } else { + const S v_G_ddelx = -vis * FILTER_INV_SQUARE * d.x; + const S v_G_ddely = -vis * FILTER_INV_SQUARE * d.y; + v_xy_local = {v_G * v_G_ddelx, v_G * v_G_ddely}; + if (v_means2d_abs != nullptr) { + v_xy_abs_local = { + abs(v_xy_local.x), abs(v_xy_local.y) + }; + } + } + v_opacity_local = vis * v_alpha; + } + + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + buffer[k] += rgbs_batch[t * COLOR_DIM + k] * fac; + } + + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + buffer_normals[k] += normals_batch[t * 3 + k] * fac; + } + } + warpSum(v_rgb_local, warp); + warpSum<3, S>(v_normal_local, warp); + warpSum(v_xy_local, warp); + warpSum(v_u_M_local, warp); + warpSum(v_v_M_local, warp); + warpSum(v_w_M_local, warp); + if (v_means2d_abs != nullptr) { + warpSum(v_xy_abs_local, warp); + } + warpSum(v_opacity_local, warp); + int32_t g = id_batch[t]; // flatten index in [C * N] or [nnz] + if (warp.thread_rank() == 0) { + S *v_rgb_ptr = (S *)(v_colors) + COLOR_DIM * g; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + gpuAtomicAdd(v_rgb_ptr + k, v_rgb_local[k]); + } + + S *v_normal_ptr = (S *)(v_normals) + 3 * g; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + gpuAtomicAdd(v_normal_ptr + k, v_normal_local[k]); + } + + S *v_ray_transforms_ptr = (S *)(v_ray_transforms) + 9 * g; + gpuAtomicAdd(v_ray_transforms_ptr, v_u_M_local.x); + gpuAtomicAdd(v_ray_transforms_ptr + 1, v_u_M_local.y); + gpuAtomicAdd(v_ray_transforms_ptr + 2, v_u_M_local.z); + gpuAtomicAdd(v_ray_transforms_ptr + 3, v_v_M_local.x); + gpuAtomicAdd(v_ray_transforms_ptr + 4, v_v_M_local.y); + gpuAtomicAdd(v_ray_transforms_ptr + 5, v_v_M_local.z); + gpuAtomicAdd(v_ray_transforms_ptr + 6, v_w_M_local.x); + gpuAtomicAdd(v_ray_transforms_ptr + 7, v_w_M_local.y); + gpuAtomicAdd(v_ray_transforms_ptr + 8, v_w_M_local.z); + + S *v_xy_ptr = (S *)(v_means2d) + 2 * g; + gpuAtomicAdd(v_xy_ptr, v_xy_local.x); + gpuAtomicAdd(v_xy_ptr + 1, v_xy_local.y); + + if (v_means2d_abs != nullptr) { + S *v_xy_abs_ptr = (S *)(v_means2d_abs) + 2 * g; + gpuAtomicAdd(v_xy_abs_ptr, v_xy_abs_local.x); + gpuAtomicAdd(v_xy_abs_ptr + 1, v_xy_abs_local.y); + } + + gpuAtomicAdd(v_opacities + g, v_opacity_local); + } + + if (valid) { + S *v_densify_ptr = (S *)(v_densify) + 2 * g; + S *v_ray_transforms_ptr = (S *)(v_ray_transforms) + 9 * g; + S depth = w_M.z; + v_densify_ptr[0] = v_ray_transforms_ptr[2] * depth; + v_densify_ptr[1] = v_ray_transforms_ptr[5] * depth; + } + } + } +} + +template +std::tuple< + torch::Tensor, + torch::Tensor, + 
torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +call_kernel_with_dim( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const torch::Tensor &densify, + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // ray_crossions + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor + &render_colors, // [C, image_height, image_width, COLOR_DIM] + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_normals, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_distort, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median, // [C, image_height, image_width, 1] + // options + bool absgrad +) { + + GSPLAT_DEVICE_GUARD(means2d); + GSPLAT_CHECK_INPUT(means2d); + GSPLAT_CHECK_INPUT(ray_transforms); + GSPLAT_CHECK_INPUT(colors); + GSPLAT_CHECK_INPUT(opacities); + GSPLAT_CHECK_INPUT(normals); + GSPLAT_CHECK_INPUT(densify); + GSPLAT_CHECK_INPUT(tile_offsets); + GSPLAT_CHECK_INPUT(flatten_ids); + GSPLAT_CHECK_INPUT(render_colors); + GSPLAT_CHECK_INPUT(render_alphas); + GSPLAT_CHECK_INPUT(last_ids); + GSPLAT_CHECK_INPUT(median_ids); + GSPLAT_CHECK_INPUT(v_render_colors); + GSPLAT_CHECK_INPUT(v_render_alphas); + GSPLAT_CHECK_INPUT(v_render_normals); + GSPLAT_CHECK_INPUT(v_render_distort); + GSPLAT_CHECK_INPUT(v_render_median); + if (backgrounds.has_value()) { + GSPLAT_CHECK_INPUT(backgrounds.value()); + } + if (masks.has_value()) { + GSPLAT_CHECK_INPUT(masks.value()); + } + + bool packed = means2d.dim() == 2; + + uint32_t C = tile_offsets.size(0); // number of cameras + uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians + uint32_t n_isects = flatten_ids.size(0); + uint32_t COLOR_DIM = colors.size(-1); + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. 
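+    // Each thread stages one gaussian per batch in dynamic shared memory
+    // (id, xy + opacity, the three rows of the 3x3 ray transform, plus
+    // COLOR_DIM color and 3 normal channels for the gradients), so the
+    // per-block byte count computed below can exceed the 48 KB default and
+    // has to be opted into via cudaFuncSetAttribute.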
+ dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + torch::Tensor v_means2d = torch::zeros_like(means2d); + torch::Tensor v_ray_transforms = torch::zeros_like(ray_transforms); + torch::Tensor v_colors = torch::zeros_like(colors); + torch::Tensor v_normals = torch::zeros_like(normals); + torch::Tensor v_opacities = torch::zeros_like(opacities); + torch::Tensor v_means2d_abs; + if (absgrad) { + v_means2d_abs = torch::zeros_like(means2d); + } + torch::Tensor v_densify = torch::zeros_like(densify); + + if (n_isects) { + const uint32_t shared_mem = + tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + + sizeof(vec3) + sizeof(vec3) + + sizeof(float) * COLOR_DIM + sizeof(float) * 3); + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + if (cudaFuncSetAttribute( + rasterize_to_pixels_bwd_2dgs_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem + ) != cudaSuccess) { + AT_ERROR( + "Failed to set maximum shared memory size (requested ", + shared_mem, + " bytes), try lowering tile_size." + ); + } + rasterize_to_pixels_bwd_2dgs_kernel + <<>>( + C, + N, + n_isects, + packed, + reinterpret_cast *>(means2d.data_ptr()), + ray_transforms.data_ptr(), + colors.data_ptr(), + normals.data_ptr(), + opacities.data_ptr(), + backgrounds.has_value() ? backgrounds.value().data_ptr() + : nullptr, + masks.has_value() ? masks.value().data_ptr() : nullptr, + image_width, + image_height, + tile_size, + tile_width, + tile_height, + tile_offsets.data_ptr(), + flatten_ids.data_ptr(), + render_colors.data_ptr(), + render_alphas.data_ptr(), + last_ids.data_ptr(), + median_ids.data_ptr(), + v_render_colors.data_ptr(), + v_render_alphas.data_ptr(), + v_render_normals.data_ptr(), + v_render_distort.data_ptr(), + v_render_median.data_ptr(), + absgrad ? 
reinterpret_cast *>( + v_means2d_abs.data_ptr() + ) + : nullptr, + reinterpret_cast *>(v_means2d.data_ptr()), + v_ray_transforms.data_ptr(), + v_colors.data_ptr(), + v_opacities.data_ptr(), + v_normals.data_ptr(), + v_densify.data_ptr() + ); + } + + return std::make_tuple( + v_means2d_abs, + v_means2d, + v_ray_transforms, + v_colors, + v_opacities, + v_normals, + v_densify + ); +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_bwd_2dgs_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3] + const torch::Tensor &colors, // [C, N, 3] or [nnz, 3] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const torch::Tensor &densify, + const at::optional &backgrounds, // [C, 3] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // ray_crossions + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids, // [n_isects] + // forward outputs + const torch::Tensor + &render_colors, // [C, image_height, image_width, COLOR_DIM] + const torch::Tensor &render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &last_ids, // [C, image_height, image_width] + const torch::Tensor &median_ids, // [C, image_height, image_width] + // gradients of outputs + const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_normals, // [C, image_height, image_width, 3] + const torch::Tensor &v_render_distort, // [C, image_height, image_width, 1] + const torch::Tensor &v_render_median, // [C, image_height, image_width, 1] + // options + bool absgrad +) { + + GSPLAT_CHECK_INPUT(colors); + uint32_t COLOR_DIM = colors.size(-1); + +#define __GS__CALL_(N) \ + case N: \ + return call_kernel_with_dim( \ + means2d, \ + ray_transforms, \ + colors, \ + opacities, \ + normals, \ + densify, \ + backgrounds, \ + masks, \ + image_width, \ + image_height, \ + tile_size, \ + tile_offsets, \ + flatten_ids, \ + render_colors, \ + render_alphas, \ + last_ids, \ + median_ids, \ + v_render_colors, \ + v_render_alphas, \ + v_render_normals, \ + v_render_distort, \ + v_render_median, \ + absgrad \ + ); + + switch (COLOR_DIM) { + __GS__CALL_(1) + __GS__CALL_(2) + __GS__CALL_(3) + __GS__CALL_(4) + __GS__CALL_(5) + __GS__CALL_(8) + __GS__CALL_(9) + __GS__CALL_(16) + __GS__CALL_(17) + __GS__CALL_(32) + __GS__CALL_(33) + __GS__CALL_(64) + __GS__CALL_(65) + __GS__CALL_(128) + __GS__CALL_(129) + __GS__CALL_(256) + __GS__CALL_(257) + __GS__CALL_(512) + __GS__CALL_(513) + default: + AT_ERROR("Unsupported number of channels: ", COLOR_DIM); + } +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu new file mode 100644 index 000000000..f3705f5f9 --- /dev/null +++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_fwd.cu @@ -0,0 +1,491 @@ +#include "bindings.h" +#include "helpers.cuh" +#include "types.cuh" +#include "utils.cuh" +#include +#include +#include + +namespace gsplat { + +namespace cg = cooperative_groups; + 
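+// The forward kernel below alpha-composites gaussians front to back with the
+// standard "over" operator. Per pixel the recurrence is, schematically (a
+// sketch of the loop below, not a separate implementation):
+//
+//     float T = 1.f;                                    // transmittance
+//     for (/* each gaussian, front to back */) {
+//         float alpha = min(0.999f, opac * expf(-sigma));
+//         float next_T = T * (1.f - alpha);
+//         if (next_T <= 1e-4f) break;                   // early termination
+//         for (int k = 0; k < COLOR_DIM; ++k)
+//             pix_out[k] += rgb[k] * alpha * T;         // weight = alpha * T
+//         T = next_T;
+//     }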
+/**************************************************************************** + * Rasterization to Pixels Forward Pass 2DGS + ****************************************************************************/ + +template +__global__ void rasterize_to_pixels_fwd_2dgs_kernel( + const uint32_t C, + const uint32_t N, + const uint32_t n_isects, + const bool packed, + const vec2 *__restrict__ means2d, + const S *__restrict__ ray_transforms, + const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM] + const S *__restrict__ opacities, // [C, N] or [nnz] + const S *__restrict__ normals, // [C, N, 3] or [nnz, 3] + const S *__restrict__ backgrounds, // [C, COLOR_DIM] + const bool *__restrict__ masks, // [C, tile_height, tile_width] + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + const uint32_t tile_width, + const uint32_t tile_height, + const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width] + const int32_t *__restrict__ flatten_ids, // [n_isects] + S *__restrict__ render_colors, // [C, image_height, image_width, COLOR_DIM] + S *__restrict__ render_alphas, // [C, image_height, image_width, 1] + S *__restrict__ render_normals, // [C, image_height, image_width, 3] + S *__restrict__ render_distort, // [C, image_height, image_width, 1] + S *__restrict__ render_median, // [C, image_height, image_width, 1] + int32_t *__restrict__ last_ids, // [C, image_height, image_width] + int32_t *__restrict__ median_ids // [C, image_height, image_width] +) { + // each thread draws one pixel, but also timeshares caching gaussians in a + // shared tile + + auto block = cg::this_thread_block(); + int32_t camera_id = block.group_index().x; + int32_t tile_id = + block.group_index().y * tile_width + block.group_index().z; + uint32_t i = block.group_index().y * tile_size + block.thread_index().y; + uint32_t j = block.group_index().z * tile_size + block.thread_index().x; + + tile_offsets += camera_id * tile_height * tile_width; + render_colors += camera_id * image_height * image_width * COLOR_DIM; + render_alphas += camera_id * image_height * image_width; + last_ids += camera_id * image_height * image_width; + if (backgrounds != nullptr) { + backgrounds += camera_id * COLOR_DIM; + } + if (masks != nullptr) { + masks += camera_id * tile_height * tile_width; + } + + S px = (S)j + 0.5f; + S py = (S)i + 0.5f; + int32_t pix_id = i * image_width + j; + + // return if out of bounds + // keep not rasterizing threads around for reading data + bool inside = (i < image_height && j < image_width); + bool done = !inside; + + // when the mask is provided, render the background color and return + // if this tile is labeled as False + if (masks != nullptr && inside && !masks[tile_id]) { + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + render_colors[pix_id * COLOR_DIM + k] = + backgrounds == nullptr ? 0.0f : backgrounds[k]; + } + return; + } + + // have all threads in tile process the same gaussians in batches + // first collect gaussians between range.x and range.y in batches + // which gaussians to look through in this tile + int32_t range_start = tile_offsets[tile_id]; + int32_t range_end = + (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1) + ? 
n_isects
+                          : tile_offsets[tile_id + 1];
+    const uint32_t block_size = block.size();
+    uint32_t num_batches =
+        (range_end - range_start + block_size - 1) / block_size;
+
+    extern __shared__ int s[];
+    int32_t *id_batch = (int32_t *)s; // [block_size]
+    vec3<S> *xy_opacity_batch =
+        reinterpret_cast<vec3<S> *>(&id_batch[block_size]); // [block_size]
+    vec3<S> *u_Ms_batch =
+        reinterpret_cast<vec3<S> *>(&xy_opacity_batch[block_size]); // [block_size]
+    vec3<S> *v_Ms_batch =
+        reinterpret_cast<vec3<S> *>(&u_Ms_batch[block_size]); // [block_size]
+    vec3<S> *w_Ms_batch =
+        reinterpret_cast<vec3<S> *>(&v_Ms_batch[block_size]); // [block_size]
+
+    // current visibility left to render. The transmittance is used in the
+    // backward pass, which would ideally be done in double for numerical
+    // precision; however, double makes the backward pass ~1.5x slower, so we
+    // stick with float for now.
+    S T = 1.0f;
+    // index of most recent gaussian to write to this thread's pixel
+    uint32_t cur_idx = 0;
+
+    // collect and process batches of gaussians
+    // each thread loads one gaussian at a time before rasterizing its
+    // designated pixel
+    uint32_t tr = block.thread_rank();
+
+    // Per-pixel distortion error proposed in Mip-NeRF 360.
+    // Reference implementation:
+    // https://github.com/nerfstudio-project/nerfacc/blob/master/nerfacc/losses.py#L7
+    S distort = 0.f;
+    S accum_vis_depth = 0.f; // accumulates vis * depth
+
+    // keep track of the median depth contribution
+    S median_depth = 0.f;
+    uint32_t median_idx = 0;
+
+    // TODO (WZ): merge pix_out and normal_out into
+    // S pix_out[COLOR_DIM + 3] = {0.f}
+    S pix_out[COLOR_DIM] = {0.f};
+    S normal_out[3] = {0.f};
+    for (uint32_t b = 0; b < num_batches; ++b) {
+        // resync all threads before beginning the next batch;
+        // end early if the entire tile is done
+        if (__syncthreads_count(done) >= block_size) {
+            break;
+        }
+
+        // each thread fetches one gaussian from front to back
+        // index of gaussian to load
+        uint32_t batch_start = range_start + block_size * b;
+        uint32_t idx = batch_start + tr;
+        if (idx < range_end) {
+            int32_t g = flatten_ids[idx]; // flatten index in [C * N] or [nnz]
+            id_batch[tr] = g;
+            const vec2<S> xy = means2d[g];
+            const S opac = opacities[g];
+            xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+            u_Ms_batch[tr] = {
+                ray_transforms[g * 9],
+                ray_transforms[g * 9 + 1],
+                ray_transforms[g * 9 + 2]
+            };
+            v_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 3],
+                ray_transforms[g * 9 + 4],
+                ray_transforms[g * 9 + 5]
+            };
+            w_Ms_batch[tr] = {
+                ray_transforms[g * 9 + 6],
+                ray_transforms[g * 9 + 7],
+                ray_transforms[g * 9 + 8]
+            };
+        }
+
+        // wait for other threads to collect the gaussians in batch
+        block.sync();
+
+        // process gaussians in the current batch for this pixel
+        uint32_t batch_size = min(block_size, range_end - batch_start);
+        for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
+
+            const vec3<S> xy_opac = xy_opacity_batch[t];
+            const S opac = xy_opac.z;
+
+            const vec3<S> u_M = u_Ms_batch[t];
+            const vec3<S> v_M = v_Ms_batch[t];
+            const vec3<S> w_M = w_Ms_batch[t];
+
+            const vec3<S> h_u = px * w_M - u_M;
+            const vec3<S> h_v = py * w_M - v_M;
+
+            const vec3<S> ray_cross = glm::cross(h_u, h_v);
+            if (ray_cross.z == 0.0)
+                continue;
+
+            const vec2<S> s =
+                vec2<S>(ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z);
+
+            const S gauss_weight_3d = s.x * s.x + s.y * s.y;
+            const vec2<S> d = {xy_opac.x - px, xy_opac.y - py};
+            const S gauss_weight_2d =
+                FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
+            const S gauss_weight = min(gauss_weight_3d, gauss_weight_2d);
+
+            const S sigma = 0.5f * gauss_weight;
+            S alpha = min(0.999f, opac *
__expf(-sigma)); + if (sigma < 0.f || alpha < 1.f / 255.f) { + continue; + } + + const S next_T = T * (1.0f - alpha); + if (next_T <= 1e-4) { // this pixel is done: exclusive + done = true; + break; + } + + int32_t g = id_batch[t]; + const S vis = alpha * T; + const S *c_ptr = colors + g * COLOR_DIM; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + pix_out[k] += c_ptr[k] * vis; + } + + const S *n_ptr = normals + g * 3; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + normal_out[k] += n_ptr[k] * vis; + } + + if (render_distort != nullptr) { + // the last channel of colors is depth + const S depth = c_ptr[COLOR_DIM - 1]; + // in nerfacc, loss_bi_0 = weights * t_mids * + // exclusive_sum(weights) + const S distort_bi_0 = vis * depth * (1.0f - T); + // in nerfacc, loss_bi_1 = weights * exclusive_sum(weights * + // t_mids) + const S distort_bi_1 = vis * accum_vis_depth; + distort += 2.0f * (distort_bi_0 - distort_bi_1); + accum_vis_depth += vis * depth; + } + + // compute median depth + if (T > 0.5) { + median_depth = c_ptr[COLOR_DIM - 1]; + median_idx = batch_start + t; + } + + cur_idx = batch_start + t; + + T = next_T; + } + } + if (inside) { + // Here T is the transmittance AFTER the last gaussian in this pixel. + // We (should) store double precision as T would be used in backward + // pass and it can be very small and causing large diff in gradients + // with float32. However, double precision makes the backward pass 1.5x + // slower so we stick with float for now. + render_alphas[pix_id] = 1.0f - T; + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < COLOR_DIM; ++k) { + render_colors[pix_id * COLOR_DIM + k] = + backgrounds == nullptr ? pix_out[k] + : (pix_out[k] + T * backgrounds[k]); + } + GSPLAT_PRAGMA_UNROLL + for (uint32_t k = 0; k < 3; ++k) { + render_normals[pix_id * 3 + k] = normal_out[k]; + } + // index in bin of last gaussian in this pixel + last_ids[pix_id] = static_cast(cur_idx); + + if (render_distort != nullptr) { + render_distort[pix_id] = distort; + } + + render_median[pix_id] = median_depth; + // index in bin of gaussian that contributes to median depth + median_ids[pix_id] = static_cast(median_idx); + } +} + +template +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +call_kernel_with_dim( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +) { + GSPLAT_DEVICE_GUARD(means2d); + GSPLAT_CHECK_INPUT(means2d); + GSPLAT_CHECK_INPUT(ray_transforms); + GSPLAT_CHECK_INPUT(colors); + GSPLAT_CHECK_INPUT(opacities); + GSPLAT_CHECK_INPUT(normals); + GSPLAT_CHECK_INPUT(tile_offsets); + GSPLAT_CHECK_INPUT(flatten_ids); + if (backgrounds.has_value()) { + GSPLAT_CHECK_INPUT(backgrounds.value()); + } + if (masks.has_value()) { + GSPLAT_CHECK_INPUT(masks.value()); + } + bool packed = means2d.dim() == 2; + + uint32_t C = tile_offsets.size(0); // number of cameras + 
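+    // In packed mode means2d is [nnz, 2] with one row per visible
+    // gaussian-camera pair, so there is no per-camera N; N is only meaningful
+    // for the dense [C, N, ...] layout, hence the 0 below.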
uint32_t N = packed ? 0 : means2d.size(1); // number of gaussians + uint32_t channels = colors.size(-1); + uint32_t tile_height = tile_offsets.size(1); + uint32_t tile_width = tile_offsets.size(2); + uint32_t n_isects = flatten_ids.size(0); + + // Each block covers a tile on the image. In total there are + // C * tile_height * tile_width blocks. + dim3 threads = {tile_size, tile_size, 1}; + dim3 blocks = {C, tile_height, tile_width}; + + torch::Tensor renders = torch::empty( + {C, image_height, image_width, channels}, + means2d.options().dtype(torch::kFloat32) + ); + torch::Tensor alphas = torch::empty( + {C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32) + ); + torch::Tensor last_ids = torch::empty( + {C, image_height, image_width}, means2d.options().dtype(torch::kInt32) + ); + torch::Tensor median_ids = torch::empty( + {C, image_height, image_width}, means2d.options().dtype(torch::kInt32) + ); + + torch::Tensor render_normals = torch::empty( + {C, image_height, image_width, 3}, + means2d.options().dtype(torch::kFloat32) + ); + torch::Tensor render_distort = torch::empty( + {C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32) + ); + torch::Tensor render_median = torch::empty( + {C, image_height, image_width, 1}, + means2d.options().dtype(torch::kFloat32) + ); + + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + const uint32_t shared_mem = + tile_size * tile_size * + (sizeof(int32_t) + sizeof(vec3) + sizeof(vec3) + + sizeof(vec3) + sizeof(vec3)); + + // TODO: an optimization can be done by passing the actual number of + // channels into the kernel functions and avoid necessary global memory + // writes. This requires moving the channel padding from python to C side. + if (cudaFuncSetAttribute( + rasterize_to_pixels_fwd_2dgs_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem + ) != cudaSuccess) { + AT_ERROR( + "Failed to set maximum shared memory size (requested ", + shared_mem, + " bytes), try lowering tile_size." + ); + } + rasterize_to_pixels_fwd_2dgs_kernel + <<>>( + C, + N, + n_isects, + packed, + reinterpret_cast *>(means2d.data_ptr()), + ray_transforms.data_ptr(), + colors.data_ptr(), + opacities.data_ptr(), + normals.data_ptr(), + backgrounds.has_value() ? backgrounds.value().data_ptr() + : nullptr, + masks.has_value() ? 
masks.value().data_ptr() : nullptr, + image_width, + image_height, + tile_size, + tile_width, + tile_height, + tile_offsets.data_ptr(), + flatten_ids.data_ptr(), + renders.data_ptr(), + alphas.data_ptr(), + render_normals.data_ptr(), + render_distort.data_ptr(), + render_median.data_ptr(), + last_ids.data_ptr(), + median_ids.data_ptr() + ); + + return std::make_tuple( + renders, + alphas, + render_normals, + render_distort, + render_median, + last_ids, + median_ids + ); +} + +std::tuple< + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor> +rasterize_to_pixels_fwd_2dgs_tensor( + // Gaussian parameters + const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2] + const torch::Tensor &ray_transforms, // [C, N, 3] or [nnz, 3] + const torch::Tensor &colors, // [C, N, channels] or [nnz, channels] + const torch::Tensor &opacities, // [C, N] or [nnz] + const torch::Tensor &normals, // [C, N, 3] or [nnz, 3] + const at::optional &backgrounds, // [C, channels] + const at::optional &masks, // [C, tile_height, tile_width] + // image size + const uint32_t image_width, + const uint32_t image_height, + const uint32_t tile_size, + // intersections + const torch::Tensor &tile_offsets, // [C, tile_height, tile_width] + const torch::Tensor &flatten_ids // [n_isects] +) { + GSPLAT_CHECK_INPUT(colors); + uint32_t channels = colors.size(-1); + +#define __GS__CALL_(N) \ + case N: \ + return call_kernel_with_dim( \ + means2d, \ + ray_transforms, \ + colors, \ + opacities, \ + normals, \ + backgrounds, \ + masks, \ + image_width, \ + image_height, \ + tile_size, \ + tile_offsets, \ + flatten_ids \ + ); + // TODO: an optimization can be done by passing the actual number of + // channels into the kernel functions and avoid necessary global memory + // writes. This requires moving the channel padding from python to C side. 
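+    // COLOR_DIM is a template parameter so the per-pixel accumulators have a
+    // compile-time size (and can live in registers); the __GS__CALL_ macro
+    // above instantiates one kernel per supported channel count, e.g.
+    // channels == 4 dispatches to call_kernel_with_dim<4>(...).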
+ switch (channels) { + __GS__CALL_(1) + __GS__CALL_(2) + __GS__CALL_(3) + __GS__CALL_(4) + __GS__CALL_(5) + __GS__CALL_(8) + __GS__CALL_(9) + __GS__CALL_(16) + __GS__CALL_(17) + __GS__CALL_(32) + __GS__CALL_(33) + __GS__CALL_(64) + __GS__CALL_(65) + __GS__CALL_(128) + __GS__CALL_(129) + __GS__CALL_(256) + __GS__CALL_(257) + __GS__CALL_(512) + __GS__CALL_(513) + default: + AT_ERROR("Unsupported number of channels: ", channels); + } +} + +} // namespace gsplat \ No newline at end of file diff --git a/gsplat/cuda/csrc/utils.cuh b/gsplat/cuda/csrc/utils.cuh index a8dcc3012..49e31fa11 100644 --- a/gsplat/cuda/csrc/utils.cuh +++ b/gsplat/cuda/csrc/utils.cuh @@ -6,6 +6,8 @@ #include #include +#define FILTER_INV_SQUARE 2.0f + namespace gsplat { template @@ -244,11 +246,7 @@ inline __device__ void ortho_proj_vjp( // df/dx = fx * df/dpixx // df/dy = fy * df/dpixy // df/dz = 0 - v_mean3d += vec3( - fx * v_mean2d[0], - fy * v_mean2d[1], - 0.f - ); + v_mean3d += vec3(fx * v_mean2d[0], fy * v_mean2d[1], 0.f); } template @@ -503,6 +501,71 @@ inline __device__ void add_blur_vjp( eps2d * det_conic_blur); } +template +inline __device__ void compute_ray_transforms_aabb_vjp( + const T *ray_transforms, + const T *v_means2d, + const vec3 v_normals, + const mat3 W, + const mat3 P, + const vec3 cam_pos, + const vec3 mean_c, + const vec4 quat, + const vec2 scale, + mat3 &_v_ray_transforms, + vec4 &v_quat, + vec2 &v_scale, + vec3 &v_mean +) { + if (v_means2d[0] != 0 || v_means2d[1] != 0) { + const T distance = ray_transforms[6] * ray_transforms[6] + ray_transforms[7] * ray_transforms[7] - + ray_transforms[8] * ray_transforms[8]; + const T f = 1 / (distance); + const T dpx_dT00 = f * ray_transforms[6]; + const T dpx_dT01 = f * ray_transforms[7]; + const T dpx_dT02 = -f * ray_transforms[8]; + const T dpy_dT10 = f * ray_transforms[6]; + const T dpy_dT11 = f * ray_transforms[7]; + const T dpy_dT12 = -f * ray_transforms[8]; + const T dpx_dT30 = ray_transforms[0] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); + const T dpx_dT31 = ray_transforms[1] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); + const T dpx_dT32 = -ray_transforms[2] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); + const T dpy_dT30 = ray_transforms[3] * (f - 2 * f * f * ray_transforms[6] * ray_transforms[6]); + const T dpy_dT31 = ray_transforms[4] * (f - 2 * f * f * ray_transforms[7] * ray_transforms[7]); + const T dpy_dT32 = -ray_transforms[5] * (f + 2 * f * f * ray_transforms[8] * ray_transforms[8]); + + _v_ray_transforms[0][0] += v_means2d[0] * dpx_dT00; + _v_ray_transforms[0][1] += v_means2d[0] * dpx_dT01; + _v_ray_transforms[0][2] += v_means2d[0] * dpx_dT02; + _v_ray_transforms[1][0] += v_means2d[1] * dpy_dT10; + _v_ray_transforms[1][1] += v_means2d[1] * dpy_dT11; + _v_ray_transforms[1][2] += v_means2d[1] * dpy_dT12; + _v_ray_transforms[2][0] += v_means2d[0] * dpx_dT30 + v_means2d[1] * dpy_dT30; + _v_ray_transforms[2][1] += v_means2d[0] * dpx_dT31 + v_means2d[1] * dpy_dT31; + _v_ray_transforms[2][2] += v_means2d[0] * dpx_dT32 + v_means2d[1] * dpy_dT32; + } + + mat3 R = quat_to_rotmat(quat); + mat3 v_M = P * glm::transpose(_v_ray_transforms); + mat3 W_t = glm::transpose(W); + mat3 v_RS = W_t * v_M; + vec3 v_tn = W_t * v_normals; + + // dual visible + vec3 tn = W * R[2]; + T cos = glm::dot(-tn, mean_c); + T multiplier = cos > 0 ? 
1 : -1; + v_tn *= multiplier; + + mat3 v_R = mat3(v_RS[0] * scale[0], v_RS[1] * scale[1], v_tn); + + quat_to_rotmat_vjp(quat, v_R, v_quat); + v_scale[0] += (T)glm::dot(v_RS[0], R[0]); + v_scale[1] += (T)glm::dot(v_RS[1], R[1]); + + v_mean += v_RS[2]; +} + } // namespace gsplat #endif // GSPLAT_CUDA_UTILS_H diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 8f14dbd28..276caa502 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -3,14 +3,17 @@ import torch import torch.distributed +import torch.nn.functional as F from torch import Tensor from typing_extensions import Literal from .cuda._wrapper import ( fully_fused_projection, + fully_fused_projection_2dgs, isect_offset_encode, isect_tiles, rasterize_to_pixels, + rasterize_to_pixels_2dgs, spherical_harmonics, ) from .distributed import ( @@ -19,6 +22,7 @@ all_to_all_int32, all_to_all_tensor_list, ) +from .utils import depth_to_normal, get_projection_matrix def rasterization( @@ -918,33 +922,443 @@ def rasterization_inria_wrapper( GaussianRasterizer, ) - def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): - tanHalfFovY = math.tan((fovY / 2)) - tanHalfFovX = math.tan((fovX / 2)) + assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" + C = len(viewmats) + device = means.device + channels = colors.shape[-1] + + render_colors = [] + for cid in range(C): + FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) + FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) + tanfovx = math.tan(FoVx * 0.5) + tanfovy = math.tan(FoVy * 0.5) + + world_view_transform = viewmats[cid].transpose(0, 1) + projection_matrix = get_projection_matrix( + znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device + ).transpose(0, 1) + full_proj_transform = ( + world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) + ).squeeze(0) + camera_center = world_view_transform.inverse()[3, :3] - top = tanHalfFovY * znear - bottom = -top - right = tanHalfFovX * znear - left = -right + background = ( + backgrounds[cid] + if backgrounds is not None + else torch.zeros(3, device=device) + ) - P = torch.zeros(4, 4, device=device) + raster_settings = GaussianRasterizationSettings( + image_height=height, + image_width=width, + tanfovx=tanfovx, + tanfovy=tanfovy, + bg=background, + scale_modifier=1.0, + viewmatrix=world_view_transform, + projmatrix=full_proj_transform, + sh_degree=0 if sh_degree is None else sh_degree, + campos=camera_center, + prefiltered=False, + debug=False, + ) - z_sign = 1.0 + rasterizer = GaussianRasterizer(raster_settings=raster_settings) - P[0, 0] = 2.0 * znear / (right - left) - P[1, 1] = 2.0 * znear / (top - bottom) - P[0, 2] = (right + left) / (right - left) - P[1, 2] = (top + bottom) / (top - bottom) - P[3, 2] = z_sign - P[2, 2] = z_sign * zfar / (zfar - znear) - P[2, 3] = -(zfar * znear) / (zfar - znear) - return P + means2D = torch.zeros_like(means, requires_grad=True, device=device) + + render_colors_ = [] + for i in range(0, channels, 3): + _colors = colors[..., i : i + 3] + if _colors.shape[-1] < 3: + pad = torch.zeros( + _colors.shape[0], 3 - _colors.shape[-1], device=device + ) + _colors = torch.cat([_colors, pad], dim=-1) + _render_colors_, radii = rasterizer( + means3D=means, + means2D=means2D, + shs=_colors if colors.dim() == 3 else None, + colors_precomp=_colors if colors.dim() == 2 else None, + opacities=opacities[:, None], + scales=scales, + rotations=quats, + cov3D_precomp=None, + ) + if _colors.shape[-1] < 3: + _render_colors_ = _render_colors_[:, :, : _colors.shape[-1]] + 
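+            # The slice above strips the zero-padding that was added when a
+            # chunk had fewer than 3 channels, so the concatenation below
+            # recovers the original channel count.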
render_colors_.append(_render_colors_)
+        render_colors_ = torch.cat(render_colors_, dim=-1)
+
+        render_colors_ = render_colors_.permute(1, 2, 0)  # [H, W, 3]
+
+        render_colors.append(render_colors_)
+    render_colors = torch.stack(render_colors, dim=0)
+    return render_colors, None, {}
+
+
+###### 2DGS ######
+def rasterization_2dgs(
+    means: Tensor,
+    quats: Tensor,
+    scales: Tensor,
+    opacities: Tensor,
+    colors: Tensor,
+    viewmats: Tensor,
+    Ks: Tensor,
+    width: int,
+    height: int,
+    near_plane: float = 0.01,
+    far_plane: float = 1e10,
+    radius_clip: float = 0.0,
+    eps2d: float = 0.3,
+    sh_degree: Optional[int] = None,
+    packed: bool = False,
+    tile_size: int = 16,
+    backgrounds: Optional[Tensor] = None,
+    render_mode: Literal["RGB", "D", "ED", "RGB+D", "RGB+ED"] = "RGB",
+    sparse_grad: bool = False,
+    absgrad: bool = False,
+    distloss: bool = False,
+    depth_mode: Literal["expected", "median"] = "expected",
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Dict]:
+    """Rasterize a set of 2D Gaussians (N) to a batch of image planes (C).
+
+    This function supports a handful of features, similar to the :func:`rasterization` function.
+
+    .. warning::
+        This function is currently not differentiable w.r.t. the camera intrinsics `Ks`.
+
+    Args:
+        means: The 3D centers of the Gaussians. [N, 3]
+        quats: The quaternions of the Gaussians (wxyz convention). They do not need to be normalized. [N, 4]
+        scales: The scales of the Gaussians. [N, 3]
+        opacities: The opacities of the Gaussians. [N]
+        colors: The colors of the Gaussians. [(C,) N, D] or [(C,) N, K, 3] for SH coefficients.
+        viewmats: The world-to-cam transformation of the cameras. [C, 4, 4]
+        Ks: The camera intrinsics. [C, 3, 3]
+        width: The width of the image.
+        height: The height of the image.
+        near_plane: The near plane for clipping. Default is 0.01.
+        far_plane: The far plane for clipping. Default is 1e10.
+        radius_clip: Gaussians with a projected 2D radius smaller than or equal to this
+            value will be skipped. This is extremely helpful for speeding up
+            large-scale scenes. Default is 0.0.
+        eps2d: An epsilon added to the eigenvalues of the projected 2D covariance
+            matrices. This prevents projected Gaussians from becoming too small. For
+            example, eps2d=0.3 leads to a minimum size of 3 pixels. Default is 0.3.
+        sh_degree: The SH degree to use, which can be smaller than the total
+            number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients,
+            else the `colors` should be [(C,) N, D] post-activation color values. Default is None.
+        packed: Whether to use packed mode, which is more memory efficient but might or
+            might not be as fast. Default is False.
+        tile_size: The size of the tiles for rasterization. Default is 16.
+            (Note: other values are not tested)
+        backgrounds: The background colors. [C, D]. Default is None.
+        render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D",
+            and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated
+            depth, and "ED" renders the expected depth. Default is "RGB".
+        sparse_grad (Experimental): If true, the gradients for {means, quats, scales} will be
+            stored in a COO sparse layout. This can be helpful for saving memory. Default is False.
+        absgrad: If true, the absolute gradients of the projected 2D means
+            will be computed during the backward pass, which could be accessed by
+            `meta["means2d"].absgrad`. Default is False.
+        distloss: If true, use distortion regularization to get better geometry detail.
+        depth_mode: Depth rendering mode. Choose between expected depth and median depth.
+
+    Returns:
+        A tuple:
+
+        **render_colors**: The rendered colors. [C, height, width, X].
+        X depends on the `render_mode` and input `colors`. If `render_mode` is "RGB",
+        X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or
+        "RGB+ED", X is D+1.
+
+        **render_alphas**: The rendered alphas. [C, height, width, 1].
+
+        **render_normals**: The rendered normals. [C, height, width, 3].
+
+        **surf_normals**: Surface normals derived from depth. [C, height, width, 3]
+
+        **render_distort**: The rendered distortions. [C, height, width, 1].
+        This is the L1 version; the 2DGS paper uses an L2 version.
+
+        **render_median**: The rendered median depth. [C, height, width, 1].
+
+        **meta**: A dictionary of intermediate results of the rasterization.
+
+    Examples:
+
+    .. code-block:: python
+
+        >>> # define Gaussians
+        >>> means = torch.randn((100, 3), device=device)
+        >>> quats = torch.randn((100, 4), device=device)
+        >>> scales = torch.rand((100, 3), device=device) * 0.1
+        >>> colors = torch.rand((100, 3), device=device)
+        >>> opacities = torch.rand((100,), device=device)
+        >>> # define cameras
+        >>> viewmats = torch.eye(4, device=device)[None, :, :]
+        >>> Ks = torch.tensor([
+        >>>    [300., 0., 150.], [0., 300., 100.], [0., 0., 1.]], device=device)[None, :, :]
+        >>> width, height = 300, 200
+        >>> # render
+        >>> colors, alphas, normals, surf_normals, distort, median_depth, meta = rasterization_2dgs(
+        >>>     means, quats, scales, opacities, colors, viewmats, Ks, width, height
+        >>> )
+        >>> print (colors.shape, alphas.shape)
+        torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 1])
+        >>> print (normals.shape, surf_normals.shape)
+        torch.Size([1, 200, 300, 3]) torch.Size([1, 200, 300, 3])
+        >>> print (distort.shape, median_depth.shape)
+        torch.Size([1, 200, 300, 1]) torch.Size([1, 200, 300, 1])
+        >>> print (meta.keys())
+        dict_keys(['camera_ids', 'gaussian_ids', 'radii', 'means2d', 'depths', 'ray_transforms',
+        'opacities', 'normals', 'tile_width', 'tile_height', 'tiles_per_gauss', 'isect_ids',
+        'flatten_ids', 'isect_offsets', 'width', 'height', 'tile_size', 'n_cameras', 'render_distort',
+        'gradient_2dgs'])
+
+    """
+
+    N = means.shape[0]
+    C = viewmats.shape[0]
+    assert means.shape == (N, 3), means.shape
+    assert quats.shape == (N, 4), quats.shape
+    assert scales.shape == (N, 3), scales.shape
+    assert opacities.shape == (N,), opacities.shape
+    assert viewmats.shape == (C, 4, 4), viewmats.shape
+    assert Ks.shape == (C, 3, 3), Ks.shape
+    assert render_mode in ["RGB", "D", "ED", "RGB+D", "RGB+ED"], render_mode
+    if distloss:
+        assert render_mode in [
+            "D",
+            "ED",
+            "RGB+D",
+            "RGB+ED",
+        ], f"distloss requires depth rendering; render_mode should be D, ED, RGB+D, or RGB+ED, but got {render_mode}"
+
+    if sh_degree is None:
+        # treat colors as post-activation values,
+        # in shape [N, D] or [C, N, D] (the latter is silently supported)
+        assert (colors.dim() == 2 and colors.shape[0] == N) or (
+            colors.dim() == 3 and colors.shape[:2] == (C, N)
+        ), colors.shape
+    else:
+        # treat colors as SH coefficients.
+        assert (
+            colors.dim() == 3 and colors.shape[0] == N and colors.shape[2] == 3
+        ), colors.shape
+        assert (sh_degree + 1) ** 2 <= colors.shape[1], colors.shape
+
+    # Compute ray-splat intersection transformation.
+    proj_results = fully_fused_projection_2dgs(
+        means,
+        quats,
+        scales,
+        viewmats,
+        Ks,
+        width,
+        height,
+        eps2d,
+        near_plane,
+        far_plane,
+        radius_clip,
+        packed,
+        sparse_grad,
+    )
+
+    if packed:
+        (
+            camera_ids,
+            gaussian_ids,
+            radii,
+            means2d,
+            depths,
+            ray_transforms,
+            normals,
+        ) = proj_results
+        opacities = opacities[gaussian_ids]
+    else:
+        radii, means2d, depths, ray_transforms, normals = proj_results
+        opacities = opacities.repeat(C, 1)
+        camera_ids, gaussian_ids = None, None
+
+    densify = torch.zeros_like(
+        means2d, dtype=means.dtype, requires_grad=True, device=means.device
+    )
+    # Identify intersecting tiles
+    tile_width = math.ceil(width / float(tile_size))
+    tile_height = math.ceil(height / float(tile_size))
+    tiles_per_gauss, isect_ids, flatten_ids = isect_tiles(
+        means2d,
+        radii,
+        depths,
+        tile_size,
+        tile_width,
+        tile_height,
+        packed=packed,
+        n_cameras=C,
+        camera_ids=camera_ids,
+        gaussian_ids=gaussian_ids,
+    )
+    isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height)
+
+    # TODO: SH should also support N-D colors.
+    # Compute the per-view colors
+    if not (
+        colors.dim() == 3 and sh_degree is None
+    ):  # silently support [C, N, D] colors
+        colors = (
+            colors[gaussian_ids] if packed else colors.expand(C, *([-1] * colors.dim()))
+        )  # [nnz, D] or [C, N, 3]
+    else:
+        if packed:
+            colors = colors[camera_ids, gaussian_ids, :]
+    if sh_degree is not None:  # SH coefficients
+        camtoworlds = torch.inverse(viewmats)
+        if packed:
+            dirs = means[gaussian_ids, :] - camtoworlds[camera_ids, :3, 3]
+        else:
+            dirs = means[None, :, :] - camtoworlds[:, None, :3, 3]
+        colors = spherical_harmonics(
+            sh_degree, dirs, colors, masks=radii > 0
+        )  # [nnz, D] or [C, N, 3]
+        # make it apples-to-apples with Inria's CUDA backend
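+        # Degree-0 SH encodes colors centered around zero, and Inria's CUDA
+        # backend shifts the SH result by +0.5 before clamping negatives away;
+        # applying the same offset here keeps both backends numerically aligned.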
+ colors = torch.clamp_min(colors + 0.5, 0.0) + + # Rasterize to pixels + if render_mode in ["RGB+D", "RGB+ED"]: + colors = torch.cat((colors, depths[..., None]), dim=-1) + # backgrounds = torch.cat((backgrounds, torch.zeros((C, 1), device="cuda")), dim=-1) + elif render_mode in ["D", "ED"]: + colors = depths[..., None] + else: # RGB + pass + + ( + render_colors, + render_alphas, + render_normals, + render_distort, + render_median, + ) = rasterize_to_pixels_2dgs( + means2d, + ray_transforms, + colors, + opacities, + normals, + densify, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds, + packed=packed, + absgrad=absgrad, + distloss=distloss, + ) + render_normals_from_depth = None + if render_mode in ["ED", "RGB+ED"]: + # normalize the accumulated depth to get the expected depth + render_colors = torch.cat( + [ + render_colors[..., :-1], + render_colors[..., -1:] / render_alphas.clamp(min=1e-10), + ], + dim=-1, + ) + if render_mode in ["RGB+ED", "RGB+D"]: + # render_depths = render_colors[..., -1:] + if depth_mode == "expected": + depth_for_normal = render_colors[..., -1:] + elif depth_mode == "median": + depth_for_normal = render_median + + render_normals_from_depth = depth_to_normal( + depth_for_normal, torch.linalg.inv(viewmats), Ks + ).squeeze(0) + + meta = { + "camera_ids": camera_ids, + "gaussian_ids": gaussian_ids, + "radii": radii, + "means2d": means2d, + "depths": depths, + "ray_transforms": ray_transforms, + "opacities": opacities, + "normals": normals, + "tile_width": tile_width, + "tile_height": tile_height, + "tiles_per_gauss": tiles_per_gauss, + "isect_ids": isect_ids, + "flatten_ids": flatten_ids, + "isect_offsets": isect_offsets, + "width": width, + "height": height, + "tile_size": tile_size, + "n_cameras": C, + "render_distort": render_distort, + "gradient_2dgs": densify, # This holds the gradient used for densification for 2dgs + } + + render_normals = render_normals @ torch.linalg.inv(viewmats)[0, :3, :3].T + + return ( + render_colors, + render_alphas, + render_normals, + render_normals_from_depth, + render_distort, + render_median, + meta, + ) + + +def rasterization_2dgs_inria_wrapper( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + opacities: Tensor, # [N] + colors: Tensor, # [N, D] or [N, K, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 100.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + backgrounds: Optional[Tensor] = None, + depth_ratio: int = 0, + **kwargs, +) -> Tuple[Tuple, Dict]: + """Wrapper for 2DGS's rasterization backend which is based on Inria's backend. 
+ + Install the 2DGS rasterization backend from + https://github.com/hbb1/diff-surfel-rasterization + + Credit to Jeffrey Hu https://github.com/jefequien + + """ + from diff_surfel_rasterization import ( + GaussianRasterizationSettings, + GaussianRasterizer, + ) assert eps2d == 0.3, "This is hard-coded in CUDA to be 0.3" C = len(viewmats) device = means.device channels = colors.shape[-1] + # rasterization from inria does not do normalization internally + quats = F.normalize(quats, dim=-1) # [N, 4] + scales = scales[:, :2] # [N, 2] + render_colors = [] for cid in range(C): FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) @@ -953,7 +1367,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): tanfovy = math.tan(FoVy * 0.5) world_view_transform = viewmats[cid].transpose(0, 1) - projection_matrix = _getProjectionMatrix( + projection_matrix = get_projection_matrix( znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=device ).transpose(0, 1) full_proj_transform = ( @@ -994,7 +1408,7 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): _colors.shape[0], 3 - _colors.shape[-1], device=device ) _colors = torch.cat([_colors, pad], dim=-1) - _render_colors_, radii = rasterizer( + _render_colors_, radii, allmap = rasterizer( means3D=means, means2D=means2D, shs=_colors if colors.dim() == 3 else None, @@ -1010,7 +1424,43 @@ def _getProjectionMatrix(znear, zfar, fovX, fovY, device="cuda"): render_colors_ = torch.cat(render_colors_, dim=-1) render_colors_ = render_colors_.permute(1, 2, 0) # [H, W, 3] - render_colors.append(render_colors_) render_colors = torch.stack(render_colors, dim=0) - return render_colors, None, {} + + # additional maps + allmap = allmap.permute(1, 2, 0).unsqueeze(0) # [1, H, W, C] + render_depth_expected = allmap[..., 0:1] + render_alphas = allmap[..., 1:2] + render_normal = allmap[..., 2:5] + render_depth_median = allmap[..., 5:6] + render_dist = allmap[..., 6:7] + + render_normal = render_normal @ (world_view_transform[:3, :3].T) + render_depth_expected = render_depth_expected / render_alphas + render_depth_expected = torch.nan_to_num(render_depth_expected, 0, 0) + render_depth_median = torch.nan_to_num(render_depth_median, 0, 0) + + # render_depth is either median or expected by setting depth_ratio to 1 or 0 + # for bounded scene, use median depth, i.e., depth_ratio = 1; + # for unbounded scene, use expected depth, i.e., depth_ratio = 0, to reduce disk aliasing. + render_depth = ( + render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median + ) + + normals_surf = depth_to_normal(render_depth, viewmats, Ks) + normals_surf = normals_surf * (render_alphas).detach() + + render_colors = torch.cat([render_colors, render_depth], dim=-1) + + meta = { + "normals_rend": render_normal, + "normals_surf": normals_surf, + "render_distloss": render_dist, + "means2d": means2D, + "width": width, + "height": height, + "radii": radii.unsqueeze(0), + "n_cameras": C, + "gaussian_ids": None, + } + return (render_colors, render_alphas), meta diff --git a/gsplat/strategy/default.py b/gsplat/strategy/default.py index 30ffae6cd..30c152d99 100644 --- a/gsplat/strategy/default.py +++ b/gsplat/strategy/default.py @@ -5,6 +5,7 @@ from .base import Strategy from .ops import duplicate, remove, reset_opa, split +from typing_extensions import Literal @dataclass @@ -54,6 +55,9 @@ class DefaultStrategy(Strategy): revised_opacity (bool): Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental). Default is False. 
        verbose (bool): Whether to print verbose information. Default is False.
+        key_for_gradient (str): Which variable to use for the densification strategy.
+            3DGS uses the "means2d" gradient, while 2DGS uses a similar gradient
+            that is stored in the variable "gradient_2dgs".
 
     Examples:
 
@@ -87,6 +91,7 @@ class DefaultStrategy(Strategy):
     absgrad: bool = False
     revised_opacity: bool = False
     verbose: bool = False
+    key_for_gradient: Literal["means2d", "gradient_2dgs"] = "means2d"
 
     def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
         """Initialize and return the running state for this strategy.
@@ -140,9 +145,9 @@
     ):
         """Callback function to be executed before the `loss.backward()` call."""
         assert (
-            "means2d" in info
-        ), "The 2D means of the Gaussians is required but missing."
+            self.key_for_gradient in info
+        ), f"The key '{self.key_for_gradient}' is required in info but missing."
-        info["means2d"].retain_grad()
+        info[self.key_for_gradient].retain_grad()
 
     def step_post_backward(
         self,
@@ -202,19 +207,27 @@ def _update_state(
         info: Dict[str, Any],
         packed: bool = False,
     ):
-        for key in ["means2d", "width", "height", "n_cameras", "radii", "gaussian_ids"]:
+        for key in [
+            "width",
+            "height",
+            "n_cameras",
+            "radii",
+            "gaussian_ids",
+            self.key_for_gradient,
+        ]:
             assert key in info, f"{key} is required but missing."
 
         # normalize grads to [-1, 1] screen space
         if self.absgrad:
-            grads = info["means2d"].absgrad.clone()
+            grads = info[self.key_for_gradient].absgrad.clone()
         else:
-            grads = info["means2d"].grad.clone()
+            grads = info[self.key_for_gradient].grad.clone()
         grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"]
         grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"]
 
         # initialize state on the first run
         n_gaussian = len(list(params.values())[0])
+
         if state["grad2d"] is None:
             state["grad2d"] = torch.zeros(n_gaussian, device=grads.device)
         if state["count"] is None:
diff --git a/gsplat/utils.py b/gsplat/utils.py
index 5663fbd50..4ff58e826 100644
--- a/gsplat/utils.py
+++ b/gsplat/utils.py
@@ -1,4 +1,7 @@
+import math
+
 import torch
+import torch.nn.functional as F
 from torch import Tensor
 
@@ -36,3 +39,227 @@ def log_transform(x):
 
 def inverse_log_transform(y):
     return torch.sign(y) * (torch.expm1(torch.abs(y)))
+
+
+def depth_to_points(
+    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
+) -> Tensor:
+    """Convert depth maps to 3D points.
+
+    Args:
+        depths: Depth maps [..., H, W, 1]
+        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
+        Ks: Camera intrinsics [..., 3, 3]
+        z_depth: Whether the depth is a z-depth (True) or a ray depth (False)
+
+    Returns:
+        points: 3D points in the world coordinate system [..., H, W, 3]
+    """
+    assert depths.shape[-1] == 1, f"Invalid depth shape: {depths.shape}"
+    assert camtoworlds.shape[-2:] == (
+        4,
+        4,
+    ), f"Invalid viewmats shape: {camtoworlds.shape}"
+    assert Ks.shape[-2:] == (3, 3), f"Invalid Ks shape: {Ks.shape}"
+    assert (
+        depths.shape[:-3] == camtoworlds.shape[:-2] == Ks.shape[:-2]
+    ), f"Shape mismatch! depths: {depths.shape}, viewmats: {camtoworlds.shape}, Ks: {Ks.shape}"
+
+    device = depths.device
+    height, width = depths.shape[-3:-1]
+
+    x, y = torch.meshgrid(
+        torch.arange(width, device=device),
+        torch.arange(height, device=device),
+        indexing="xy",
+    )  # [H, W]
+
+    fx = Ks[..., 0, 0]  # [...]
+    fy = Ks[..., 1, 1]  # [...]
+    cx = Ks[..., 0, 2]  # [...]
+    cy = Ks[..., 1, 2]  # [...]
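+
+    # Pixel centers sit at half-integer coordinates (hence the +0.5 offsets in
+    # camera_dirs below); dividing by the focal lengths turns pixel offsets into
+    # ray slopes in the camera frame, and padding with 1.0 fixes the camera-frame
+    # z component to 1 so that scaling by a z-depth recovers metric 3D points.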
+    # camera directions in camera coordinates
+    camera_dirs = F.pad(
+        torch.stack(
+            [
+                (x - cx[..., None, None] + 0.5) / fx[..., None, None],
+                (y - cy[..., None, None] + 0.5) / fy[..., None, None],
+            ],
+            dim=-1,
+        ),
+        (0, 1),
+        value=1.0,
+    )  # [..., H, W, 3]
+
+    # ray directions in world coordinates
+    directions = torch.einsum(
+        "...ij,...hwj->...hwi", camtoworlds[..., :3, :3], camera_dirs
+    )  # [..., H, W, 3]
+    origins = camtoworlds[..., :3, -1]  # [..., 3]
+
+    if not z_depth:
+        directions = F.normalize(directions, dim=-1)
+
+    points = origins[..., None, None, :] + depths * directions
+    return points
+
+
+def depth_to_normal(
+    depths: Tensor, camtoworlds: Tensor, Ks: Tensor, z_depth: bool = True
+) -> Tensor:
+    """Convert depth maps to surface normals.
+
+    Args:
+        depths: Depth maps [..., H, W, 1]
+        camtoworlds: Camera-to-world transformation matrices [..., 4, 4]
+        Ks: Camera intrinsics [..., 3, 3]
+        z_depth: Whether the depth is a z-depth (True) or a ray depth (False)
+
+    Returns:
+        normals: Surface normals in the world coordinate system [..., H, W, 3]
+    """
+    points = depth_to_points(depths, camtoworlds, Ks, z_depth=z_depth)  # [..., H, W, 3]
+    dx = torch.cat(
+        [points[..., 2:, 1:-1, :] - points[..., :-2, 1:-1, :]], dim=-3
+    )  # [..., H-2, W-2, 3]
+    dy = torch.cat(
+        [points[..., 1:-1, 2:, :] - points[..., 1:-1, :-2, :]], dim=-2
+    )  # [..., H-2, W-2, 3]
+    normals = F.normalize(torch.cross(dx, dy, dim=-1), dim=-1)  # [..., H-2, W-2, 3]
+    normals = F.pad(normals, (0, 0, 1, 1, 1, 1), value=0.0)  # [..., H, W, 3]
+    return normals
+
+
+def get_projection_matrix(znear, zfar, fovX, fovY, device="cuda"):
+    """Create an OpenGL-style projection matrix."""
+    tanHalfFovY = math.tan(fovY / 2)
+    tanHalfFovX = math.tan(fovX / 2)
+
+    top = tanHalfFovY * znear
+    bottom = -top
+    right = tanHalfFovX * znear
+    left = -right
+
+    P = torch.zeros(4, 4, device=device)
+
+    z_sign = 1.0
+
+    P[0, 0] = 2.0 * znear / (right - left)
+    P[1, 1] = 2.0 * znear / (top - bottom)
+    P[0, 2] = (right + left) / (right - left)
+    P[1, 2] = (top + bottom) / (top - bottom)
+    P[3, 2] = z_sign
+    P[2, 2] = z_sign * zfar / (zfar - znear)
+    P[2, 3] = -(zfar * znear) / (zfar - znear)
+    return P
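+
+
+# A minimal sanity check for the helpers above (illustrative only; assumes a
+# CUDA device and uses made-up intrinsics):
+#
+#   depths = torch.rand(1, 480, 640, 1, device="cuda") + 1.0
+#   camtoworlds = torch.eye(4, device="cuda")[None]
+#   Ks = torch.tensor(
+#       [[300.0, 0.0, 320.0], [0.0, 300.0, 240.0], [0.0, 0.0, 1.0]],
+#       device="cuda",
+#   )[None]
+#   normals = depth_to_normal(depths, camtoworlds, Ks)  # [1, 480, 640, 3]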
+# """ +# height, width = depths.shape[1:3] +# viewmats = torch.linalg.inv(camtoworlds) # [C, 4, 4] + +# normals = [] +# for cid, depth in enumerate(depths): +# FoVx = 2 * math.atan(width / (2 * Ks[cid, 0, 0].item())) +# FoVy = 2 * math.atan(height / (2 * Ks[cid, 1, 1].item())) +# world_view_transform = viewmats[cid].transpose(0, 1) +# projection_matrix = _get_projection_matrix( +# znear=near_plane, zfar=far_plane, fovX=FoVx, fovY=FoVy, device=depths.device +# ).transpose(0, 1) +# full_proj_transform = ( +# world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0)) +# ).squeeze(0) +# normal = _depth_to_normal( +# depth, +# world_view_transform, +# full_proj_transform, +# Ks[cid, 0, 0], +# Ks[cid, 1, 1], +# ) +# normals.append(normal) +# normals = torch.stack(normals, dim=0) +# return normals + + +# # ref: https://github.com/hbb1/2d-gaussian-splatting/blob/61c7b417393d5e0c58b742ad5e2e5f9e9f240cc6/utils/point_utils.py#L26 +# def _depths_to_points( +# depthmap, world_view_transform, full_proj_transform, fx, fy +# ) -> Tensor: +# c2w = (world_view_transform.T).inverse() +# H, W = depthmap.shape[:2] + +# intrins = ( +# torch.tensor([[fx, 0.0, W / 2.0], [0.0, fy, H / 2.0], [0.0, 0.0, 1.0]]) +# .float() +# .cuda() +# ) + +# grid_x, grid_y = torch.meshgrid( +# torch.arange(W, device="cuda").float(), +# torch.arange(H, device="cuda").float(), +# indexing="xy", +# ) +# points = torch.stack([grid_x, grid_y, torch.ones_like(grid_x)], dim=-1).reshape( +# -1, 3 +# ) +# rays_d = points @ intrins.inverse().T @ c2w[:3, :3].T +# rays_o = c2w[:3, 3] +# points = depthmap.reshape(-1, 1) * rays_d + rays_o +# return points + + +# def _depth_to_normal( +# depth, world_view_transform, full_proj_transform, fx, fy +# ) -> Tensor: +# points = _depths_to_points( +# depth, +# world_view_transform, +# full_proj_transform, +# fx, +# fy, +# ).reshape(*depth.shape[:2], 3) +# output = torch.zeros_like(points) +# dx = torch.cat([points[2:, 1:-1] - points[:-2, 1:-1]], dim=0) +# dy = torch.cat([points[1:-1, 2:] - points[1:-1, :-2]], dim=1) +# normal_map = torch.nn.functional.normalize(torch.cross(dx, dy, dim=-1), dim=-1) +# output[1:-1, 1:-1, :] = normal_map +# return output + + +# def _get_projection_matrix(znear, zfar, fovX, fovY, device="cuda") -> Tensor: +# tanHalfFovY = math.tan((fovY / 2)) +# tanHalfFovX = math.tan((fovX / 2)) + +# top = tanHalfFovY * znear +# bottom = -top +# right = tanHalfFovX * znear +# left = -right + +# P = torch.zeros(4, 4, device=device) + +# z_sign = 1.0 + +# P[0, 0] = 2.0 * znear / (right - left) +# P[1, 1] = 2.0 * znear / (top - bottom) +# P[0, 2] = (right + left) / (right - left) +# P[1, 2] = (top + bottom) / (top - bottom) +# P[3, 2] = z_sign +# P[2, 2] = z_sign * zfar / (zfar - znear) +# P[2, 3] = -(zfar * znear) / (zfar - znear) +# return P diff --git a/tests/test_2dgs.py b/tests/test_2dgs.py new file mode 100644 index 000000000..9956fb494 --- /dev/null +++ b/tests/test_2dgs.py @@ -0,0 +1,400 @@ +import math + +import pytest +import torch +import pdb + +from gsplat._helper import load_test_data + +device = torch.device("cuda:0") + + +@pytest.fixture +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +def test_data(): + N = 2 + xs = torch.linspace(-1, 1, N, device=device) + ys = torch.linspace(-1, 1, N, device=device) + xys = torch.stack(torch.meshgrid(xs, ys), dim=-1).reshape(-1, 2) + zs = torch.ones_like(xys[:, :1]) * 3 + means = torch.cat([xys, zs], dim=-1) + quats = torch.tensor([[1.0, 0.0, 0.0, 0]], device=device).repeat(len(means), 1) + scales = 
+    scales[..., :2] *= 0.1
+    opacities = torch.ones(1, len(means), device=device) * 0.5
+    colors = torch.rand(1, len(means), 3, device=device)
+    viewmats = torch.eye(4, device=device).reshape(1, 4, 4)
+    # W, H = 24, 20
+    W, H = 640, 480
+    fx, fy, cx, cy = W, W, W // 2, H // 2
+    Ks = torch.tensor(
+        [[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], device=device
+    ).reshape(1, 3, 3)
+    return {
+        "means": means,
+        "quats": quats,
+        "scales": scales,
+        "opacities": opacities,
+        "colors": colors,
+        "viewmats": viewmats,
+        "Ks": Ks,
+        "width": W,
+        "height": H,
+    }
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+def test_projection_2dgs(test_data):
+    from gsplat.cuda._torch_impl_2dgs import _fully_fused_projection_2dgs
+    from gsplat.cuda._wrapper import fully_fused_projection_2dgs
+
+    torch.manual_seed(42)
+
+    Ks = test_data["Ks"]
+    viewmats = test_data["viewmats"]
+    height = test_data["height"]
+    width = test_data["width"]
+    quats = test_data["quats"]
+    scales = test_data["scales"]
+    means = test_data["means"]
+    viewmats.requires_grad = True
+    quats.requires_grad = True
+    scales.requires_grad = True
+    means.requires_grad = True
+
+    # forward
+    _radii, _means2d, _depths, _ray_transforms, _normals = _fully_fused_projection_2dgs(
+        means, quats, scales, viewmats, Ks, width, height
+    )
+    _ray_transforms = _ray_transforms.permute(
+        (0, 1, 3, 2)
+    )  # TODO(WZ): figure out why we need to permute here
+
+    radii, means2d, depths, ray_transforms, normals = fully_fused_projection_2dgs(
+        means, quats, scales, viewmats, Ks, width, height
+    )
+
+    # TODO(WZ): is the following true for 2DGS as well?
+    # radii is an integer, so we allow for a 1-unit difference
+    valid = (radii > 0) & (_radii > 0)
+    torch.testing.assert_close(radii, _radii, rtol=0, atol=1)
+    torch.testing.assert_close(means2d[valid], _means2d[valid], rtol=1e-4, atol=1e-4)
+    torch.testing.assert_close(depths[valid], _depths[valid], rtol=1e-4, atol=1e-4)
+    torch.testing.assert_close(
+        ray_transforms[valid], _ray_transforms[valid], rtol=1e-4, atol=1e-4
+    )
+    torch.testing.assert_close(normals[valid], _normals[valid], rtol=1e-4, atol=1e-4)
+
+    # backward
+    v_means2d = torch.randn_like(means2d) * radii[..., None]
+    v_depths = torch.randn_like(depths) * radii
+    v_ray_transforms = torch.randn_like(ray_transforms) * radii[..., None, None]
+    v_normals = torch.randn_like(normals) * radii[..., None]
+
+    v_quats, v_scales, v_means = torch.autograd.grad(
+        (means2d * v_means2d).sum()
+        + (depths * v_depths).sum()
+        + (ray_transforms * v_ray_transforms).sum()
+        + (normals * v_normals).sum(),
+        (quats, scales, means),
+    )
+    _v_quats, _v_scales, _v_means = torch.autograd.grad(
+        (_means2d * v_means2d).sum()
+        + (_depths * v_depths).sum()
+        + (_ray_transforms * v_ray_transforms).sum()
+        + (_normals * v_normals).sum(),
+        (quats, scales, means),
+    )
+
+    # torch.testing.assert_close(v_viewmats, _v_viewmats, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(v_quats, _v_quats, rtol=2e-1, atol=1e-2)
+    torch.testing.assert_close(
+        v_scales[..., :2], _v_scales[..., :2], rtol=1e-1, atol=2e-1
+    )
+    torch.testing.assert_close(v_means, _v_means, rtol=1e-2, atol=6e-2)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device")
+@pytest.mark.parametrize(
+    "sparse_grad",
+    [
+        False,
+        # True,  # no sparse-grad support for now
+    ],
+)
+def test_fully_fused_projection_packed_2dgs(
+    test_data,
+    sparse_grad: bool,
+):
+    from gsplat.cuda._wrapper import fully_fused_projection_2dgs
+
+    torch.manual_seed(42)
+
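+    # Test strategy: run the projection twice, once packed and once dense,
+    # scatter the packed outputs back into dense [C, N, ...] tensors via sparse
+    # COO, and require both code paths to agree in the forward and backward
+    # passes.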
Ks = test_data["Ks"] + viewmats = test_data["viewmats"] + height = test_data["height"] + width = test_data["width"] + quats = test_data["quats"] + scales = test_data["scales"] + means = test_data["means"] + viewmats.requires_grad = False + quats.requires_grad = True + scales.requires_grad = True + means.requires_grad = True + + ( + camera_ids, + gaussian_ids, + radii, + means2d, + depths, + ray_transforms, + normals, + ) = fully_fused_projection_2dgs( + means, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=True, + sparse_grad=sparse_grad, + ) + + _radii, _means2d, _depths, _ray_transforms, _normals = fully_fused_projection_2dgs( + means, + quats, + scales, + viewmats, + Ks, + width, + height, + packed=False, + ) + # recover packed tensors to full matrices for testing + __radii = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), radii, _radii.shape + ).to_dense() + __means2d = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), means2d, _means2d.shape + ).to_dense() + __depths = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), depths, _depths.shape + ).to_dense() + __ray_transforms = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), ray_transforms, _ray_transforms.shape + ).to_dense() + __normals = torch.sparse_coo_tensor( + torch.stack([camera_ids, gaussian_ids]), normals, _normals.shape + ).to_dense() + + sel = (__radii > 0) & (_radii > 0) + torch.testing.assert_close(__radii[sel], _radii[sel], rtol=0, atol=1) + torch.testing.assert_close(__means2d[sel], _means2d[sel], rtol=1e-4, atol=1e-4) + torch.testing.assert_close(__depths[sel], _depths[sel], rtol=1e-4, atol=1e-4) + torch.testing.assert_close( + __ray_transforms[sel], _ray_transforms[sel], rtol=1e-4, atol=1e-4 + ) + torch.testing.assert_close(__normals[sel], _normals[sel], rtol=1e-4, atol=1e-4) + + # backward + v_means2d = torch.randn_like(_means2d) * sel[..., None] + v_depths = torch.randn_like(_depths) * sel + v_ray_transforms = torch.randn_like(_ray_transforms) * sel[..., None, None] + v_normals = torch.randn_like(_normals) * sel[..., None] + _v_quats, _v_scales, _v_means = torch.autograd.grad( + (_means2d * v_means2d).sum() + + (_depths * v_depths).sum() + + (_ray_transforms * v_ray_transforms).sum() + + (_normals * v_normals).sum(), + (quats, scales, means), + retain_graph=True, + ) + v_quats, v_scales, v_means = torch.autograd.grad( + (means2d * v_means2d[__radii > 0]).sum() + + (depths * v_depths[__radii > 0]).sum() + + (ray_transforms * v_ray_transforms[__radii > 0]).sum() + + (normals * v_normals[__radii > 0]).sum(), + (quats, scales, means), + retain_graph=True, + ) + if sparse_grad: + v_quats = v_quats.to_dense() + v_scales = v_scales.to_dense() + v_means = v_means.to_dense() + + torch.testing.assert_close(v_scales, _v_scales, rtol=5e-2, atol=5e-2) + torch.testing.assert_close(v_means, _v_means, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(v_quats, _v_quats, rtol=1e-2, atol=1e-2) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") +# @pytest.mark.parametrize("channels", [3, 32, 128]) +def test_rasterize_to_pixels_2dgs(test_data): + from gsplat.cuda._torch_impl_2dgs import _rasterize_to_pixels_2dgs + from gsplat.cuda._wrapper import ( + fully_fused_projection_2dgs, + isect_offset_encode, + isect_tiles, + rasterize_to_pixels_2dgs, + ) + from gsplat.rendering import rasterization_2dgs_inria_wrapper + + Ks = test_data["Ks"] + viewmats = test_data["viewmats"] + height = test_data["height"] + width = 
test_data["width"] + quats = test_data["quats"] + scales = test_data["scales"] + means = test_data["means"] + colors = test_data["colors"] + opacities = test_data["opacities"] + + N = means.shape[0] + C = viewmats.shape[0] + + radii, means2d, depths, ray_transforms, normals = fully_fused_projection_2dgs( + means, quats, scales, viewmats, Ks, width, height + ) + + colors = torch.concatenate((colors, depths.unsqueeze(-1)), dim=-1) + backgrounds = torch.zeros((C, colors.shape[-1]), device=device) + + # Identify intersecting tiles + tile_size = 16 + tile_width = math.ceil(width / float(tile_size)) + tile_height = math.ceil(height / float(tile_size)) + tiles_per_gauss, isect_ids, flatten_ids = isect_tiles( + means2d, radii, depths, tile_size, tile_width, tile_height + ) + isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height) + + colors = colors.repeat(C, 1, 1) + opacities = opacities.repeat(C, 1) + normals = normals.repeat(C, 1, 1) + densify = torch.zeros_like(means2d, device=means2d.device) + + means2d.requires_grad = True + ray_transforms.requires_grad = True + colors.requires_grad = True + opacities.requires_grad = True + backgrounds.requires_grad = True + normals.requires_grad = True + densify.requires_grad = True + + ( + render_colors, + render_alphas, + render_normals, + _, + render_median, + ) = rasterize_to_pixels_2dgs( + means2d, + ray_transforms, + colors, + opacities, + normals, + densify, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds, + distloss=True, + ) + + # ray_transforms_torch = ray_transforms.transpose(-1, -2).clone() + _render_colors, _render_alphas, _render_normals = _rasterize_to_pixels_2dgs( + means2d, + ray_transforms, + colors, + normals, + opacities, + width, + height, + tile_size, + isect_offsets, + flatten_ids, + backgrounds=backgrounds, + ) + + cuda_render = render_colors[0].detach().cpu() + torch_render = _render_colors[0].detach().cpu() + diff = (cuda_render - torch_render).abs() + if diff.max() > 1e-5: + print(f"DIFF > 1e-5, {diff.max()=} {diff.mean()=}") + import os + import imageio + + os.makedirs("renders", exist_ok=True) + imageio.imwrite("renders/cuda_render.png", (255 * cuda_render).byte()) + imageio.imwrite("renders/torch_render.png", (255 * torch_render).byte()) + imageio.imwrite("renders/diff.png", (255 * diff).byte()) + + v_render_colors = torch.rand_like(render_colors) + v_render_alphas = torch.rand_like(render_alphas) + v_render_normals = torch.rand_like(render_normals) + + ( + v_means2d, + v_ray_transforms, + v_colors, + v_opacities, + v_backgrounds, + v_normals, + ) = torch.autograd.grad( + (render_colors * v_render_colors).sum() + + (render_alphas * v_render_alphas).sum() + + (render_normals * v_render_normals).sum(), + (means2d, ray_transforms, colors, opacities, backgrounds, normals), + ) + + ( + _v_means2d, + _v_ray_transforms, + _v_colors, + _v_opacities, + _v_backgrounds, + _v_normals, + ) = torch.autograd.grad( + (_render_colors * v_render_colors).sum() + + (_render_alphas * v_render_alphas).sum() + + (_render_normals * v_render_normals).sum(), + (means2d, ray_transforms, colors, opacities, backgrounds, normals), + ) + + pairs = { + "v_means2d": (v_means2d, _v_means2d), + "v_ray_transforms": (v_ray_transforms, _v_ray_transforms), + "v_colors": (v_colors, _v_colors), + "v_opacities": (v_opacities, _v_opacities), + "v_backgrounds": (v_backgrounds, _v_backgrounds), + "v_normals": (v_normals, _v_normals), + } + for name, (v, _v) in pairs.items(): + diff = (v - _v).abs() + 
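+        # report each tensor's max/mean gradient discrepancy before the hard
+        # asserts below, which makes tolerance failures easier to triage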
+        print(f"{name=} {v.shape} {diff.max()=} {diff.mean()=}")
+
+    # assert close forward
+    torch.testing.assert_close(render_colors, _render_colors)
+    torch.testing.assert_close(render_alphas, _render_alphas)
+    torch.testing.assert_close(render_normals, _render_normals)
+
+    # assert close backward
+    torch.testing.assert_close(v_means2d, _v_means2d, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(
+        v_ray_transforms, _v_ray_transforms, rtol=1e-3, atol=1e-3
+    )
+    torch.testing.assert_close(v_colors, _v_colors, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(v_opacities, _v_opacities, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(v_backgrounds, _v_backgrounds, rtol=1e-5, atol=1e-5)
+    torch.testing.assert_close(v_normals, _v_normals, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    # The test functions consume the `test_data` pytest fixture, which cannot
+    # be called directly, so route standalone runs through pytest instead.
+    pytest.main(["-v", __file__])
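+
+
+# Illustrative only (hypothetical training-loop fragment, not executed by these
+# tests): the "gradient_2dgs" entry in `meta` is what the densification strategy
+# in gsplat/strategy/default.py consumes when `key_for_gradient` is switched
+# away from the 3DGS default of "means2d":
+#
+#   strategy = DefaultStrategy(key_for_gradient="gradient_2dgs")
+#   ...
+#   renders, alphas, normals, surf_normals, distort, median, meta = (
+#       rasterization_2dgs(
+#           means, quats, scales, opacities, colors, viewmats, Ks, width, height
+#       )
+#   )
+#   strategy.step_pre_backward(params, optimizers, state, step, meta)
+#   loss.backward()
+#   strategy.step_post_backward(params, optimizers, state, step, meta)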