diff --git a/docs/source/apis/rasterization.rst b/docs/source/apis/rasterization.rst
index f5f6d7eb3..0d0831953 100644
--- a/docs/source/apis/rasterization.rst
+++ b/docs/source/apis/rasterization.rst
@@ -1,6 +1,9 @@
Rasterization
===================================
+3DGS
+------
+
.. currentmodule:: gsplat
Given a set of 3D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, covariances
@@ -38,4 +41,36 @@ projection equation:
Where :math:`[W | t]` is the world-to-camera transformation matrix, and :math:`f_{x}, f_{y}`
are the focal lengths of the camera.
-.. autofunction:: rasterization
\ No newline at end of file
+.. autofunction:: rasterization
+
+2DGS
+------
+
+Given a set of 2D gaussians parametrized by means :math:`\mu \in \mathbb{R}^3`, two principal tangent vectors
+embedded as the first two columns of a rotation matrix :math:`R \in \mathbb{R}^{3\times3}`, and a scale matrix :math:`S \in \mathbb{R}^{3\times3}`
+representing the scaling along the two principal tangential directions, we first transform pixels into the splats' local tangent frames
+by :math:`(WH)^{-1} \in \mathbb{R}^{4\times4}` and compute the Gaussian weights via ray-splat intersection. Sorting and rendering then proceed as in 3DGS.
+
+Note that :math:`H` is the transformation from the splat's local tangent plane :math:`\{u, v\}` into world space:
+
+.. math::
+
+ H = \begin{bmatrix}
+ RS & \mu \\
+ 0 & 1
+ \end{bmatrix}
+
+and :math:`W \in \mathbb{R}^{4\times4}` is the transformation matrix from world space to image space.
+
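+As a minimal illustration (not part of the gsplat API), :math:`H` can be assembled for a
+single splat in PyTorch, reusing the :code:`_quat_to_rotmat` helper added in this PR; the
+third scale entry is zero because the splat is flat:
+
+.. code-block:: python
+
+    import torch
+
+    from gsplat.cuda._torch_impl import _quat_to_rotmat
+
+    def build_H(quat: torch.Tensor, scale2d: torch.Tensor, mu: torch.Tensor) -> torch.Tensor:
+        """Sketch: tangent-plane-to-world transform H for one splat.
+
+        quat: [4] rotation quaternion, scale2d: [2] tangential scales, mu: [3] mean.
+        """
+        R = _quat_to_rotmat(quat[None])[0]  # [3, 3]; columns 0-1 span the tangent plane
+        S = torch.diag(torch.cat([scale2d, scale2d.new_zeros(1)]))  # third scale is 0
+        H = torch.eye(4)
+        H[:3, :3] = R @ S  # RS: scaled tangent vectors in the first two columns
+        H[:3, 3] = mu  # splat center
+        return H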
+
+Splatting is done via ray-splat plane intersection. Each pixel is considered as an x-plane :math:`h_{x}=(-1, 0, 0, x)^{T}`
+and a y-plane :math:`h_{y}=(0, -1, 0, y)^{T}`, and the intersection between a splat and the pixel :math:`p=(x, y)` is defined
+as the intersection between the x-plane, the y-plane, and the splat's tangent plane. We first transform :math:`h_{x}` to :math:`h_{u}` and :math:`h_{y}`
+to :math:`h_{v}` in the splat's tangent frame via the inverse transformation :math:`(WH)^{-1}`. Since the intersection point must lie on both :math:`h_{u}` and :math:`h_{v}`, it admits an efficient
+closed-form solution:
+
+.. math::
+ u(p) = \frac{h^{2}_{u}h^{4}_{v}-h^{4}_{u}h^{2}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}},
+ v(p) = \frac{h^{4}_{u}h^{1}_{v}-h^{1}_{u}h^{4}_{v}}{h^{1}_{u}h^{2}_{v}-h^{2}_{u}h^{1}_{v}}
+
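+A minimal PyTorch sketch of this intersection (illustrative only; it mirrors the
+equivalent cross-product form used by :code:`accumulate_2dgs` in this PR):
+
+.. code-block:: python
+
+    import torch
+
+    def intersect_splat(M: torch.Tensor, x: float, y: float) -> torch.Tensor:
+        """M is the 3x3 ray transform (WH)^T of one splat; returns (u, v)."""
+        h_u = -M[0, :] + x * M[2, :]  # x-plane mapped into the splat's tangent frame
+        h_v = -M[1, :] + y * M[2, :]  # y-plane mapped into the splat's tangent frame
+        uvw = torch.cross(h_u, h_v, dim=-1)  # homogeneous intersection of the two planes
+        return uvw[:2] / uvw[2]  # dehomogenize to splat-local (u, v)
+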
+.. autofunction:: rasterization_2dgs
\ No newline at end of file
diff --git a/docs/source/apis/utils.rst b/docs/source/apis/utils.rst
index cd83393be..cc839dabe 100644
--- a/docs/source/apis/utils.rst
+++ b/docs/source/apis/utils.rst
@@ -4,6 +4,9 @@ Utils
Below are the basic functions that support the rasterization.
+3DGS
+-----
+
.. currentmodule:: gsplat
.. autofunction:: spherical_harmonics
@@ -27,3 +30,17 @@ Below are the basic functions that supports the rasterization.
.. autofunction:: accumulate
.. autofunction:: rasterization_inria_wrapper
+
+2DGS
+-----
+
+.. currentmodule:: gsplat
+
+.. autofunction:: fully_fused_projection_2dgs
+
+.. autofunction:: rasterize_to_pixels_2dgs
+
+.. autofunction:: rasterize_to_indices_in_range_2dgs
+
+.. autofunction:: accumulate_2dgs
+
+.. autofunction:: rasterization_2dgs_inria_wrapper
\ No newline at end of file
diff --git a/docs/source/tests/eval.rst b/docs/source/tests/eval.rst
index 9f2a2fdc0..49c13525e 100644
--- a/docs/source/tests/eval.rst
+++ b/docs/source/tests/eval.rst
@@ -1,6 +1,9 @@
Evaluation
===================================
+3DGS
+----------------------------------------------
+
.. table:: Performance on `Mip-NeRF 360 Captures `_ (Averaged Over 7 Scenes)
+---------------------+-------+-------+-------+------------------+------------+
@@ -140,3 +143,100 @@ The evaluation of `inria-X` can be
reproduced with our forked version of the official implementation at
`here `_,
with the command :code:`python full_eval_m360.py` (commit 36546ce).
+
+2DGS
+----------------------------------------------
+
+No Regularization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. table:: Performance on `Mip-NeRF 360 Captures `_ (Averaged Over 7 Scenes)
+
++---------------------+-------+-------+-------+------------------+------------+
+| | PSNR | SSIM | LPIPS | Train Mem | Train Time |
++=====================+=======+=======+=======+==================+============+
+| inria-30k | 28.73 | 0.860 | 0.148 | 3.73 GB | 22m16s |
++---------------------+-------+-------+-------+------------------+------------+
+| gsplat-30k | 28.76 | 0.867 | 0.145 | **3.70 GB** | **15m44s** |
++---------------------+-------+-------+-------+------------------+------------+
+
+With Normal Consistency and Distortion Regularization
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
++---------------------+-------+-------+-------+------------------+------------+
+| | PSNR | SSIM | LPIPS | Train Mem | Train Time |
++=====================+=======+=======+=======+==================+============+
+| inria-30k | 28.05 | 0.848 | 0.186 | 3.76 GB | 22m06s |
++---------------------+-------+-------+-------+------------------+------------+
+| gsplat-30k | 27.80 | 0.842 | 0.169 | **3.61 GB** | **16m44s** |
++---------------------+-------+-------+-------+------------------+------------+
+
+Runtime and GPU Memory
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+| Train Mem (GB) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++=================+=========+========+=========+========+=========+========+========+
+| inria-30k |**6.74** | 2.27 | 2.06 | 4.79 | 2.25 | 2.40 |**5.58**|
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+| gsplat-30k | 6.89 |**2.19**| **1.93**|**4.48**| **2.14**|**2.30**| 6.00 |
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+| Train Time (s) | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++=================+=========+========+=========+========+=========+========+========+
+| inria-30k | 1463 | 1237 | 1318 | 1298 | 1422 | 1314 | 1252 |
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+| gsplat-30k |**1231** |**788** | **803**| **985**| **828**| **789**|**1057**|
++-----------------+---------+--------+---------+--------+---------+--------+--------+
+
+
+Reproduced Metrics
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
++------------+---------+--------+---------+--------+---------+-------+-------+
+| PSNR | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++============+=========+========+=========+========+=========+=======+=======+
+| inria-30k | 24.92 | 31.87 | 28.78 | 26.88 | 31.08 | 31.21 | 26.36 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 24.97 | 31.94 | 28.76 | 26.95 | 31.08 | 31.27 | 26.37 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+
++------------+---------+--------+---------+--------+---------+-------+-------+
+| SSIM | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++============+=========+========+=========+========+=========+=======+=======+
+| inria-30k | 0.741 | 0.937 | 0.899 | 0.847 | 0.921 | 0.914 | 0.760 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 0.764 | 0.937 | 0.899 | 0.849 | 0.921 | 0.915 | 0.761 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+
++------------+---------+--------+---------+--------+---------+-------+-------+
+| LPIPS | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++============+=========+========+=========+========+=========+=======+=======+
+| inria-30k | 0.199 | 0.136 | 0.164 | 0.093 | 0.101 | 0.172 | 0.168 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 0.189 | 0.134 | 0.162 | 0.091 | 0.101 | 0.169 | 0.166 |
++------------+---------+--------+---------+--------+---------+-------+-------+
+
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+| Number of GSs | Bicycle | Bonsai | Counter | Garden | Kitchen | Room | Stump |
++=================+=========+========+=========+========+=========+=======+=======+
+| inria-30k | 3.97M | 0.91M | 0.72M | 2.79M | 0.85M | 1.01M | 3.27M |
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+| gsplat-30k | 3.88M | 0.92M | 0.73M | 2.49M | 0.87M | 1.03M | 3.40M |
++-----------------+---------+--------+---------+--------+---------+-------+-------+
+
+Note: Evaluations for 2DGS are conducted on an NVIDIA RTX 4090 GPU. The LPIPS metric is evaluated
+using :code:`from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity`, whereas
+the original paper uses :code:`from lpipsPyTorch import lpips`.
+
+The evaluation of `gsplat-X` can be reproduced with the command
+:code:`cd examples; bash benchmarks/basic_2dgs.sh`
+within the gsplat repo (commit 48abf70).
+
+The evaluation of `inria-X` can be
+reproduced with our forked version of the official implementation at
+`here `_;
+you need to change :code:`--model_type 2dgs` to :code:`--model_type 2dgs-inria` in
+:code:`benchmarks/basic_2dgs.sh` and run the command :code:`cd examples; bash benchmarks/basic_2dgs.sh` (commit 28c928a).
\ No newline at end of file
diff --git a/examples/benchmarks/basic_2dgs.sh b/examples/benchmarks/basic_2dgs.sh
new file mode 100755
index 000000000..04d3d8fd4
--- /dev/null
+++ b/examples/benchmarks/basic_2dgs.sh
@@ -0,0 +1,52 @@
+SCENE_DIR="data/360_v2"
+RESULT_DIR="results/benchmark_2dgs"
+SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers
+
+for SCENE in $SCENE_LIST;
+do
+ if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then
+ DATA_FACTOR=2
+ else
+ DATA_FACTOR=4
+ fi
+
+ echo "Running $SCENE"
+
+ # train without eval
+ CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \
+ --model_type 2dgs \
+        --data_dir $SCENE_DIR/$SCENE/ \
+ --result_dir $RESULT_DIR/$SCENE/
+
+ # run eval and render
+ for CKPT in $RESULT_DIR/$SCENE/ckpts/*;
+ do
+ CUDA_VISIBLE_DEVICES=0 python simple_trainer_2dgs.py --disable_viewer --data_factor $DATA_FACTOR \
+ --model_type 2dgs \
+        --data_dir $SCENE_DIR/$SCENE/ \
+ --result_dir $RESULT_DIR/$SCENE/ \
+ --ckpt $CKPT
+ done
+done
+
+
+for SCENE in $SCENE_LIST;
+do
+ echo "=== Eval Stats ==="
+
+ for STATS in $RESULT_DIR/$SCENE/stats/val*.json;
+ do
+ echo $STATS
+ cat $STATS;
+ echo
+ done
+
+ echo "=== Train Stats ==="
+
+ for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json;
+ do
+ echo $STATS
+ cat $STATS;
+ echo
+ done
+done
\ No newline at end of file
diff --git a/examples/image_fitting.py b/examples/image_fitting.py
index 672f7a1aa..434b7869b 100644
--- a/examples/image_fitting.py
+++ b/examples/image_fitting.py
@@ -2,7 +2,7 @@
import os
import time
from pathlib import Path
-from typing import Optional
+from typing import Literal, Optional
import numpy as np
import torch
@@ -10,7 +10,7 @@
from PIL import Image
from torch import Tensor, optim
-from gsplat import rasterization
+from gsplat import rasterization, rasterization_2dgs
class SimpleTrainer:
@@ -79,6 +79,7 @@ def train(
iterations: int = 1000,
lr: float = 0.01,
save_imgs: bool = False,
+ model_type: Literal["3dgs", "2dgs"] = "3dgs",
):
optimizer = optim.Adam(
[self.rgbs, self.means, self.scales, self.opacities, self.quats], lr
@@ -94,9 +95,16 @@ def train(
],
device=self.device,
)
+
+ if model_type == "3dgs":
+ rasterize_fnc = rasterization
+ elif model_type == "2dgs":
+ rasterize_fnc = rasterization_2dgs
+
for iter in range(iterations):
start = time.time()
- renders, _, _ = rasterization(
+
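+        # Both rasterization() and rasterization_2dgs() return the rendered
+        # image as the first element of their output tuple, so indexing [0]
+        # below works for either model type.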
+ renders = rasterize_fnc(
self.means,
self.quats / self.quats.norm(dim=-1, keepdim=True),
self.scales,
@@ -107,7 +115,7 @@ def train(
self.W,
self.H,
packed=False,
- )
+ )[0]
out_img = renders[0]
torch.cuda.synchronize()
times[0] += time.time() - start
@@ -125,7 +133,7 @@ def train(
if save_imgs:
# save them as a gif with PIL
frames = [Image.fromarray(frame) for frame in frames]
- out_dir = os.path.join(os.getcwd(), "renders")
+ out_dir = os.path.join(os.getcwd(), "results")
os.makedirs(out_dir, exist_ok=True)
frames[0].save(
f"{out_dir}/training.gif",
@@ -158,6 +166,7 @@ def main(
img_path: Optional[Path] = None,
iterations: int = 1000,
lr: float = 0.01,
+ model_type: Literal["3dgs", "2dgs"] = "3dgs",
) -> None:
if img_path:
gt_image = image_path_to_tensor(img_path)
@@ -172,6 +181,7 @@ def main(
iterations=iterations,
lr=lr,
save_imgs=save_imgs,
+ model_type=model_type,
)
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 6ebe9abeb..148b61135 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -17,4 +17,5 @@ tyro>=0.8.8
Pillow
tensorboard
pyyaml
+matplotlib
git+https://github.com/rahul-goel/fused-ssim@84422e0da94c516220eb3acedb907e68809e9e01
diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 53c751575..a1bf7c052 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -15,7 +15,7 @@
import viser
import yaml
from datasets.colmap import Dataset, Parser
-from datasets.traj import generate_interpolated_path, generate_ellipse_path_z
+from datasets.traj import generate_ellipse_path_z, generate_interpolated_path
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
diff --git a/examples/simple_trainer_2dgs.py b/examples/simple_trainer_2dgs.py
new file mode 100644
index 000000000..c37191f66
--- /dev/null
+++ b/examples/simple_trainer_2dgs.py
@@ -0,0 +1,970 @@
+import json
+import math
+import os
+import time
+from dataclasses import dataclass, field
+from typing import Dict, List, Literal, Optional, Tuple
+
+import imageio
+import nerfview
+import numpy as np
+import torch
+import torch.nn.functional as F
+import tqdm
+import tyro
+import viser
+from datasets.colmap import Dataset, Parser
+from datasets.traj import generate_interpolated_path
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
+from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+from utils import (
+ AppearanceOptModule,
+ CameraOptModule,
+ apply_depth_colormap,
+ colormap,
+ knn,
+ rgb_to_sh,
+ set_random_seed,
+)
+
+from gsplat.rendering import rasterization_2dgs, rasterization_2dgs_inria_wrapper
+from gsplat.strategy import DefaultStrategy
+
+
+@dataclass
+class Config:
+ # Disable viewer
+ disable_viewer: bool = False
+    # Path to the .pt file. If provided, it will skip training and render a video
+ ckpt: Optional[str] = None
+
+ # Path to the Mip-NeRF 360 dataset
+ data_dir: str = "data/360_v2/garden"
+ # Downsample factor for the dataset
+ data_factor: int = 4
+ # Directory to save results
+ result_dir: str = "results/garden"
+ # Every N images there is a test image
+ test_every: int = 8
+ # Random crop size for training (experimental)
+ patch_size: Optional[int] = None
+ # A global scaler that applies to the scene size related parameters
+ global_scale: float = 1.0
+
+ # Port for the viewer server
+ port: int = 8080
+
+ # Batch size for training. Learning rates are scaled automatically
+ batch_size: int = 1
+ # A global factor to scale the number of training steps
+ steps_scaler: float = 1.0
+
+ # Number of training steps
+ max_steps: int = 30_000
+ # Steps to evaluate the model
+ eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+ # Steps to save the model
+ save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
+
+ # Initialization strategy
+ init_type: str = "sfm"
+ # Initial number of GSs. Ignored if using sfm
+ init_num_pts: int = 100_000
+ # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
+ init_extent: float = 3.0
+ # Degree of spherical harmonics
+ sh_degree: int = 3
+    # Turn on another SH degree every this many steps
+ sh_degree_interval: int = 1000
+ # Initial opacity of GS
+ init_opa: float = 0.1
+ # Initial scale of GS
+ init_scale: float = 1.0
+ # Weight for SSIM loss
+ ssim_lambda: float = 0.2
+
+ # Near plane clipping distance
+ near_plane: float = 0.2
+ # Far plane clipping distance
+ far_plane: float = 200
+
+ # GSs with opacity below this value will be pruned
+ prune_opa: float = 0.05
+ # GSs with image plane gradient above this value will be split/duplicated
+ grow_grad2d: float = 0.0002
+ # GSs with scale below this value will be duplicated. Above will be split
+ grow_scale3d: float = 0.01
+ # GSs with scale above this value will be pruned.
+ prune_scale3d: float = 0.1
+
+ # Start refining GSs after this iteration
+ refine_start_iter: int = 500
+ # Stop refining GSs after this iteration
+ refine_stop_iter: int = 15_000
+    # Reset opacities every this many steps
+    reset_every: int = 3000
+    # Refine GSs every this many steps
+    refine_every: int = 100
+
+    # Use packed mode for rasterization; this uses less memory but is slightly slower.
+ packed: bool = False
+ # Use sparse gradients for optimization. (experimental)
+ sparse_grad: bool = False
+ # Use absolute gradient for pruning. This typically requires larger --grow_grad2d, e.g., 0.0008 or 0.0006
+ absgrad: bool = False
+ # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+ antialiased: bool = False
+ # Whether to use revised opacity heuristic from arXiv:2404.06109 (experimental)
+ revised_opacity: bool = False
+
+ # Use random background for training to discourage transparency
+ random_bkgd: bool = False
+
+ # Enable camera optimization.
+ pose_opt: bool = False
+ # Learning rate for camera optimization
+ pose_opt_lr: float = 1e-5
+ # Regularization for camera optimization as weight decay
+ pose_opt_reg: float = 1e-6
+ # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+ pose_noise: float = 0.0
+
+ # Enable appearance optimization. (experimental)
+ app_opt: bool = False
+ # Appearance embedding dimension
+ app_embed_dim: int = 16
+ # Learning rate for appearance optimization
+ app_opt_lr: float = 1e-3
+ # Regularization for appearance optimization as weight decay
+ app_opt_reg: float = 1e-6
+
+ # Enable depth loss. (experimental)
+ depth_loss: bool = False
+ # Weight for depth loss
+ depth_lambda: float = 1e-2
+
+ # Enable normal consistency loss. (Currently for 2DGS only)
+ normal_loss: bool = False
+ # Weight for normal loss
+ normal_lambda: float = 5e-2
+    # Iteration to start normal consistency regularization
+ normal_start_iter: int = 7_000
+
+ # Distortion loss. (experimental)
+ dist_loss: bool = False
+ # Weight for distortion loss
+ dist_lambda: float = 1e-2
+    # Iteration to start distortion loss regularization
+ dist_start_iter: int = 3_000
+
+ # Model for splatting.
+ model_type: Literal["2dgs", "2dgs-inria"] = "2dgs"
+
+    # Dump information to tensorboard every this many steps
+ tb_every: int = 100
+ # Save training images to tensorboard
+ tb_save_image: bool = False
+
+ def adjust_steps(self, factor: float):
+ self.eval_steps = [int(i * factor) for i in self.eval_steps]
+ self.save_steps = [int(i * factor) for i in self.save_steps]
+ self.max_steps = int(self.max_steps * factor)
+ self.sh_degree_interval = int(self.sh_degree_interval * factor)
+ self.refine_start_iter = int(self.refine_start_iter * factor)
+ self.refine_stop_iter = int(self.refine_stop_iter * factor)
+ self.reset_every = int(self.reset_every * factor)
+ self.refine_every = int(self.refine_every * factor)
+
+
+def create_splats_with_optimizers(
+ parser: Parser,
+ init_type: str = "sfm",
+ init_num_pts: int = 100_000,
+ init_extent: float = 3.0,
+ init_opacity: float = 0.1,
+ init_scale: float = 1.0,
+ scene_scale: float = 1.0,
+ sh_degree: int = 3,
+ sparse_grad: bool = False,
+ batch_size: int = 1,
+ feature_dim: Optional[int] = None,
+ device: str = "cuda",
+) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+ if init_type == "sfm":
+ points = torch.from_numpy(parser.points).float()
+ rgbs = torch.from_numpy(parser.points_rgb / 255.0).float()
+ elif init_type == "random":
+ points = init_extent * scene_scale * (torch.rand((init_num_pts, 3)) * 2 - 1)
+ rgbs = torch.rand((init_num_pts, 3))
+ else:
+ raise ValueError("Please specify a correct init_type: sfm or random")
+
+ N = points.shape[0]
+ # Initialize the GS size to be the average dist of the 3 nearest neighbors
+ dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1) # [N,]
+ dist_avg = torch.sqrt(dist2_avg)
+ scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 3) # [N, 3]
+ quats = torch.rand((N, 4)) # [N, 4]
+ opacities = torch.logit(torch.full((N,), init_opacity)) # [N,]
+
+ params = [
+ # name, value, lr
+ ("means", torch.nn.Parameter(points), 1.6e-4 * scene_scale),
+ ("scales", torch.nn.Parameter(scales), 5e-3),
+ ("quats", torch.nn.Parameter(quats), 1e-3),
+ ("opacities", torch.nn.Parameter(opacities), 5e-2),
+ ]
+
+ if feature_dim is None:
+ # color is SH coefficients.
+ colors = torch.zeros((N, (sh_degree + 1) ** 2, 3)) # [N, K, 3]
+ colors[:, 0, :] = rgb_to_sh(rgbs)
+ params.append(("sh0", torch.nn.Parameter(colors[:, :1, :]), 2.5e-3))
+ params.append(("shN", torch.nn.Parameter(colors[:, 1:, :]), 2.5e-3 / 20))
+ else:
+ # features will be used for appearance and view-dependent shading
+ features = torch.rand(N, feature_dim) # [N, feature_dim]
+ params.append(("features", torch.nn.Parameter(features), 2.5e-3))
+ colors = torch.logit(rgbs) # [N, 3]
+ params.append(("colors", torch.nn.Parameter(colors), 2.5e-3))
+
+ splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
+ # Scale learning rate based on batch size, reference:
+ # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
+ # Note that this would not make the training exactly equivalent, see
+ # https://arxiv.org/pdf/2402.18824v1
+ optimizers = {
+ name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+ [{"params": splats[name], "lr": lr * math.sqrt(batch_size)}],
+ eps=1e-15 / math.sqrt(batch_size),
+ betas=(1 - batch_size * (1 - 0.9), 1 - batch_size * (1 - 0.999)),
+ )
+ for name, _, lr in params
+ }
+ return splats, optimizers
+
+
+class Runner:
+ """Engine for training and testing."""
+
+ def __init__(self, cfg: Config) -> None:
+ set_random_seed(42)
+
+ self.cfg = cfg
+ self.device = "cuda"
+
+ # Where to dump results.
+ os.makedirs(cfg.result_dir, exist_ok=True)
+
+ # Setup output directories.
+ self.ckpt_dir = f"{cfg.result_dir}/ckpts"
+ os.makedirs(self.ckpt_dir, exist_ok=True)
+ self.stats_dir = f"{cfg.result_dir}/stats"
+ os.makedirs(self.stats_dir, exist_ok=True)
+ self.render_dir = f"{cfg.result_dir}/renders"
+ os.makedirs(self.render_dir, exist_ok=True)
+
+ # Tensorboard
+ self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb")
+
+ # Load data: Training data should contain initial points and colors.
+ self.parser = Parser(
+ data_dir=cfg.data_dir,
+ factor=cfg.data_factor,
+ normalize=True,
+ test_every=cfg.test_every,
+ )
+ self.trainset = Dataset(
+ self.parser,
+ split="train",
+ patch_size=cfg.patch_size,
+ load_depths=cfg.depth_loss,
+ )
+ self.valset = Dataset(self.parser, split="val")
+ self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale
+ print("Scene scale:", self.scene_scale)
+
+ # Model
+ feature_dim = 32 if cfg.app_opt else None
+ self.splats, self.optimizers = create_splats_with_optimizers(
+ self.parser,
+ init_type=cfg.init_type,
+ init_num_pts=cfg.init_num_pts,
+ init_extent=cfg.init_extent,
+ init_opacity=cfg.init_opa,
+ init_scale=cfg.init_scale,
+ scene_scale=self.scene_scale,
+ sh_degree=cfg.sh_degree,
+ sparse_grad=cfg.sparse_grad,
+ batch_size=cfg.batch_size,
+ feature_dim=feature_dim,
+ device=self.device,
+ )
+ print("Model initialized. Number of GS:", len(self.splats["means"]))
+ self.model_type = cfg.model_type
+
+ if self.model_type == "2dgs":
+ key_for_gradient = "gradient_2dgs"
+ else:
+ key_for_gradient = "means2d"
+
+ # Densification Strategy
+ self.strategy = DefaultStrategy(
+ verbose=True,
+ prune_opa=cfg.prune_opa,
+ grow_grad2d=cfg.grow_grad2d,
+ grow_scale3d=cfg.grow_scale3d,
+ prune_scale3d=cfg.prune_scale3d,
+ # refine_scale2d_stop_iter=4000, # splatfacto behavior
+ refine_start_iter=cfg.refine_start_iter,
+ refine_stop_iter=cfg.refine_stop_iter,
+ reset_every=cfg.reset_every,
+ refine_every=cfg.refine_every,
+ absgrad=cfg.absgrad,
+ revised_opacity=cfg.revised_opacity,
+ key_for_gradient=key_for_gradient,
+ )
+ self.strategy.check_sanity(self.splats, self.optimizers)
+ self.strategy_state = self.strategy.initialize_state()
+
+ self.pose_optimizers = []
+ if cfg.pose_opt:
+ self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_adjust.zero_init()
+ self.pose_optimizers = [
+ torch.optim.Adam(
+ self.pose_adjust.parameters(),
+ lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size),
+ weight_decay=cfg.pose_opt_reg,
+ )
+ ]
+
+ if cfg.pose_noise > 0.0:
+ self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device)
+ self.pose_perturb.random_init(cfg.pose_noise)
+
+ self.app_optimizers = []
+ if cfg.app_opt:
+ self.app_module = AppearanceOptModule(
+ len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree
+ ).to(self.device)
+ # initialize the last layer to be zero so that the initial output is zero.
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight)
+ torch.nn.init.zeros_(self.app_module.color_head[-1].bias)
+ self.app_optimizers = [
+ torch.optim.Adam(
+ self.app_module.embeds.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0,
+ weight_decay=cfg.app_opt_reg,
+ ),
+ torch.optim.Adam(
+ self.app_module.color_head.parameters(),
+ lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size),
+ ),
+ ]
+
+ # Losses & Metrics.
+ self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
+ self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device)
+ self.lpips = LearnedPerceptualImagePatchSimilarity(normalize=True).to(
+ self.device
+ )
+
+ # Viewer
+ if not self.cfg.disable_viewer:
+ self.server = viser.ViserServer(port=cfg.port, verbose=False)
+ self.viewer = nerfview.Viewer(
+ server=self.server,
+ render_fn=self._viewer_render_fn,
+ mode="training",
+ )
+
+ def rasterize_splats(
+ self,
+ camtoworlds: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Dict]:
+ means = self.splats["means"] # [N, 3]
+ # quats = F.normalize(self.splats["quats"], dim=-1) # [N, 4]
+ # rasterization does normalization internally
+ quats = self.splats["quats"] # [N, 4]
+ scales = torch.exp(self.splats["scales"]) # [N, 3]
+ opacities = torch.sigmoid(self.splats["opacities"]) # [N,]
+
+ image_ids = kwargs.pop("image_ids", None)
+ if self.cfg.app_opt:
+ colors = self.app_module(
+ features=self.splats["features"],
+ embed_ids=image_ids,
+ dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+ sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree),
+ )
+ colors = colors + self.splats["colors"]
+ colors = torch.sigmoid(colors)
+ else:
+ colors = torch.cat([self.splats["sh0"], self.splats["shN"]], 1) # [N, K, 3]
+
+ assert self.cfg.antialiased is False, "Antialiased is not supported for 2DGS"
+
+ if self.model_type == "2dgs":
+ (
+ render_colors,
+ render_alphas,
+ render_normals,
+ normals_from_depth,
+ render_distort,
+ render_median,
+ info,
+ ) = rasterization_2dgs(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=self.cfg.absgrad,
+ sparse_grad=self.cfg.sparse_grad,
+ **kwargs,
+ )
+ elif self.model_type == "2dgs-inria":
+ render_colors, render_alphas, info = rasterization_2dgs_inria_wrapper(
+ means=means,
+ quats=quats,
+ scales=scales,
+ opacities=opacities,
+ colors=colors,
+ viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4]
+ Ks=Ks, # [C, 3, 3]
+ width=width,
+ height=height,
+ packed=self.cfg.packed,
+ absgrad=self.cfg.absgrad,
+ sparse_grad=self.cfg.sparse_grad,
+ **kwargs,
+ )
+ render_normals = info["normals_rend"]
+ normals_from_depth = info["normals_surf"]
+ render_distort = info["render_distloss"]
+ render_median = render_colors[..., 3]
+
+ return (
+ render_colors,
+ render_alphas,
+ render_normals,
+ normals_from_depth,
+ render_distort,
+ render_median,
+ info,
+ )
+
+ def train(self):
+ cfg = self.cfg
+ device = self.device
+
+ # Dump cfg.
+ with open(f"{cfg.result_dir}/cfg.json", "w") as f:
+ json.dump(vars(cfg), f)
+
+ max_steps = cfg.max_steps
+ init_step = 0
+
+ schedulers = [
+            # means has a learning rate schedule that ends at 0.01 of the initial value
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.optimizers["means"], gamma=0.01 ** (1.0 / max_steps)
+ ),
+ ]
+ if cfg.pose_opt:
+ # pose optimization has a learning rate schedule
+ schedulers.append(
+ torch.optim.lr_scheduler.ExponentialLR(
+ self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps)
+ )
+ )
+
+ trainloader = torch.utils.data.DataLoader(
+ self.trainset,
+ batch_size=cfg.batch_size,
+ shuffle=True,
+ num_workers=4,
+ persistent_workers=True,
+ pin_memory=True,
+ )
+ trainloader_iter = iter(trainloader)
+
+ # Training loop.
+ global_tic = time.time()
+ pbar = tqdm.tqdm(range(init_step, max_steps))
+ for step in pbar:
+ if not cfg.disable_viewer:
+ while self.viewer.state.status == "paused":
+ time.sleep(0.01)
+ self.viewer.lock.acquire()
+ tic = time.time()
+
+ try:
+ data = next(trainloader_iter)
+ except StopIteration:
+ trainloader_iter = iter(trainloader)
+ data = next(trainloader_iter)
+
+ camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4]
+ Ks = data["K"].to(device) # [1, 3, 3]
+ pixels = data["image"].to(device) / 255.0 # [1, H, W, 3]
+ num_train_rays_per_step = (
+ pixels.shape[0] * pixels.shape[1] * pixels.shape[2]
+ )
+ image_ids = data["image_id"].to(device)
+ if cfg.depth_loss:
+ points = data["points"].to(device) # [1, M, 2]
+ depths_gt = data["depths"].to(device) # [1, M]
+
+ height, width = pixels.shape[1:3]
+
+ if cfg.pose_noise:
+ camtoworlds = self.pose_perturb(camtoworlds, image_ids)
+
+ if cfg.pose_opt:
+ camtoworlds = self.pose_adjust(camtoworlds, image_ids)
+
+ # sh schedule
+ sh_degree_to_use = min(step // cfg.sh_degree_interval, cfg.sh_degree)
+
+ # forward
+ (
+ renders,
+ alphas,
+ normals,
+ normals_from_depth,
+ render_distort,
+ render_median,
+ info,
+ ) = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=sh_degree_to_use,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ image_ids=image_ids,
+ render_mode="RGB+ED" if cfg.depth_loss else "RGB+D",
+ distloss=self.cfg.dist_loss,
+ )
+ if renders.shape[-1] == 4:
+ colors, depths = renders[..., 0:3], renders[..., 3:4]
+ else:
+ colors, depths = renders, None
+
+ if cfg.random_bkgd:
+ bkgd = torch.rand(1, 3, device=device)
+ colors = colors + bkgd * (1.0 - alphas)
+
+ self.strategy.step_pre_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ )
+
+ # loss
+ l1loss = F.l1_loss(colors, pixels)
+ ssimloss = 1.0 - self.ssim(
+ pixels.permute(0, 3, 1, 2), colors.permute(0, 3, 1, 2)
+ )
+ loss = l1loss * (1.0 - cfg.ssim_lambda) + ssimloss * cfg.ssim_lambda
+ if cfg.depth_loss:
+ # query depths from depth map
+ points = torch.stack(
+ [
+ points[:, :, 0] / (width - 1) * 2 - 1,
+ points[:, :, 1] / (height - 1) * 2 - 1,
+ ],
+ dim=-1,
+ ) # normalize to [-1, 1]
+ grid = points.unsqueeze(2) # [1, M, 1, 2]
+ depths = F.grid_sample(
+ depths.permute(0, 3, 1, 2), grid, align_corners=True
+ ) # [1, 1, M, 1]
+ depths = depths.squeeze(3).squeeze(1) # [1, M]
+ # calculate loss in disparity space
+ disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths))
+ disp_gt = 1.0 / depths_gt # [1, M]
+ depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale
+ loss += depthloss * cfg.depth_lambda
+
+ if cfg.normal_loss:
+ if step > cfg.normal_start_iter:
+ curr_normal_lambda = cfg.normal_lambda
+ else:
+ curr_normal_lambda = 0.0
+ # normal consistency loss
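+                # penalize misalignment between rendered normals and normals
+                # derived from the rendered depth: per-pixel error is
+                # 1 - cos(angle between the two normal maps)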
+ normals = normals.squeeze(0).permute((2, 0, 1))
+ normals_from_depth *= alphas.squeeze(0).detach()
+ if len(normals_from_depth.shape) == 4:
+ normals_from_depth = normals_from_depth.squeeze(0)
+ normals_from_depth = normals_from_depth.permute((2, 0, 1))
+ normal_error = (1 - (normals * normals_from_depth).sum(dim=0))[None]
+ normalloss = curr_normal_lambda * normal_error.mean()
+ loss += normalloss
+
+ if cfg.dist_loss:
+ if step > cfg.dist_start_iter:
+ curr_dist_lambda = cfg.dist_lambda
+ else:
+ curr_dist_lambda = 0.0
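+                # depth-distortion regularizer from the 2DGS paper (adapted from
+                # Mip-NeRF 360): encourages each ray's weights to concentrate
+                # around a single surface; averaged over all pixels below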
+                distloss = render_distort.mean()
+                loss += distloss * curr_dist_lambda
+
+ loss.backward()
+
+ desc = f"loss={loss.item():.3f}| " f"sh degree={sh_degree_to_use}| "
+ if cfg.depth_loss:
+ desc += f"depth loss={depthloss.item():.6f}| "
+ if cfg.dist_loss:
+ desc += f"dist loss={distloss.item():.6f}"
+ if cfg.pose_opt and cfg.pose_noise:
+ # monitor the pose error if we inject noise
+ pose_err = F.l1_loss(camtoworlds_gt, camtoworlds)
+ desc += f"pose err={pose_err.item():.6f}| "
+ pbar.set_description(desc)
+
+ if cfg.tb_every > 0 and step % cfg.tb_every == 0:
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ self.writer.add_scalar("train/loss", loss.item(), step)
+ self.writer.add_scalar("train/l1loss", l1loss.item(), step)
+ self.writer.add_scalar("train/ssimloss", ssimloss.item(), step)
+ self.writer.add_scalar("train/num_GS", len(self.splats["means"]), step)
+ self.writer.add_scalar("train/mem", mem, step)
+ if cfg.depth_loss:
+ self.writer.add_scalar("train/depthloss", depthloss.item(), step)
+ if cfg.normal_loss:
+ self.writer.add_scalar("train/normalloss", normalloss.item(), step)
+ if cfg.dist_loss:
+ self.writer.add_scalar("train/distloss", distloss.item(), step)
+ if cfg.tb_save_image:
+ canvas = (
+ torch.cat([pixels, colors[..., :3]], dim=2)
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ canvas = canvas.reshape(-1, *canvas.shape[2:])
+ self.writer.add_image("train/render", canvas, step)
+ self.writer.flush()
+
+ self.strategy.step_post_backward(
+ params=self.splats,
+ optimizers=self.optimizers,
+ state=self.strategy_state,
+ step=step,
+ info=info,
+ packed=cfg.packed,
+ )
+
+ # Turn Gradients into Sparse Tensor before running optimizer
+ if cfg.sparse_grad:
+ assert cfg.packed, "Sparse gradients only work with packed mode."
+ gaussian_ids = info["gaussian_ids"]
+ for k in self.splats.keys():
+ grad = self.splats[k].grad
+ if grad is None or grad.is_sparse:
+ continue
+ self.splats[k].grad = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=grad[gaussian_ids], # [nnz, ...]
+ size=self.splats[k].size(), # [N, ...]
+ is_coalesced=len(Ks) == 1,
+ )
+
+ # optimize
+ for optimizer in self.optimizers.values():
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for optimizer in self.pose_optimizers:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for optimizer in self.app_optimizers:
+ optimizer.step()
+ optimizer.zero_grad(set_to_none=True)
+ for scheduler in schedulers:
+ scheduler.step()
+
+ # save checkpoint
+ if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1:
+ mem = torch.cuda.max_memory_allocated() / 1024**3
+ stats = {
+ "mem": mem,
+ "ellipse_time": time.time() - global_tic,
+ "num_GS": len(self.splats["means"]),
+ }
+ print("Step: ", step, stats)
+ with open(f"{self.stats_dir}/train_step{step:04d}.json", "w") as f:
+ json.dump(stats, f)
+ torch.save(
+ {
+ "step": step,
+ "splats": self.splats.state_dict(),
+ },
+ f"{self.ckpt_dir}/ckpt_{step}.pt",
+ )
+
+ # eval the full set
+ if step in [i - 1 for i in cfg.eval_steps] or step == max_steps - 1:
+ self.eval(step)
+ self.render_traj(step)
+
+ if not cfg.disable_viewer:
+ self.viewer.lock.release()
+ num_train_steps_per_sec = 1.0 / (time.time() - tic)
+ num_train_rays_per_sec = (
+ num_train_rays_per_step * num_train_steps_per_sec
+ )
+ # Update the viewer state.
+ self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec
+ # Update the scene.
+ self.viewer.update(step, num_train_rays_per_step)
+
+ @torch.no_grad()
+ def eval(self, step: int):
+ """Entry for evaluation."""
+ print("Running evaluation...")
+ cfg = self.cfg
+ device = self.device
+
+ valloader = torch.utils.data.DataLoader(
+ self.valset, batch_size=1, shuffle=False, num_workers=1
+ )
+ ellipse_time = 0
+ metrics = {"psnr": [], "ssim": [], "lpips": []}
+ for i, data in enumerate(valloader):
+ camtoworlds = data["camtoworld"].to(device)
+ Ks = data["K"].to(device)
+ pixels = data["image"].to(device) / 255.0
+ height, width = pixels.shape[1:3]
+
+ torch.cuda.synchronize()
+ tic = time.time()
+ (
+ colors,
+ alphas,
+ normals,
+ normals_from_depth,
+ render_distort,
+ render_median,
+ _,
+ ) = self.rasterize_splats(
+ camtoworlds=camtoworlds,
+ Ks=Ks,
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ render_mode="RGB+ED",
+ ) # [1, H, W, 3]
+ colors = torch.clamp(colors, 0.0, 1.0)
+ colors = colors[..., :3] # Take RGB channels
+ torch.cuda.synchronize()
+ ellipse_time += time.time() - tic
+
+ # write images
+ canvas = torch.cat([pixels, colors], dim=2).squeeze(0).cpu().numpy()
+ imageio.imwrite(
+ f"{self.render_dir}/val_{i:04d}.png", (canvas * 255).astype(np.uint8)
+ )
+
+ # write median depths
+ render_median = (render_median - render_median.min()) / (
+ render_median.max() - render_median.min()
+ )
+ # render_median = render_median.detach().cpu().squeeze(0).unsqueeze(-1).repeat(1, 1, 3).numpy()
+ render_median = (
+ render_median.detach().cpu().squeeze(0).repeat(1, 1, 3).numpy()
+ )
+
+ imageio.imwrite(
+ f"{self.render_dir}/val_{i:04d}_median_depth_{step}.png",
+ (render_median * 255).astype(np.uint8),
+ )
+
+ # write normals
+ normals = (normals * 0.5 + 0.5).squeeze(0).cpu().numpy()
+ normals_output = (normals * 255).astype(np.uint8)
+ imageio.imwrite(
+ f"{self.render_dir}/val_{i:04d}_normal_{step}.png", normals_output
+ )
+
+ # write normals from depth
+ normals_from_depth *= alphas.squeeze(0).detach()
+ normals_from_depth = (normals_from_depth * 0.5 + 0.5).cpu().numpy()
+ normals_from_depth = (normals_from_depth - np.min(normals_from_depth)) / (
+ np.max(normals_from_depth) - np.min(normals_from_depth)
+ )
+ normals_from_depth_output = (normals_from_depth * 255).astype(np.uint8)
+ if len(normals_from_depth_output.shape) == 4:
+ normals_from_depth_output = normals_from_depth_output.squeeze(0)
+ imageio.imwrite(
+ f"{self.render_dir}/val_{i:04d}_normals_from_depth_{step}.png",
+ normals_from_depth_output,
+ )
+
+ # write distortions
+
+ render_dist = render_distort
+ dist_max = torch.max(render_dist)
+ dist_min = torch.min(render_dist)
+ render_dist = (render_dist - dist_min) / (dist_max - dist_min)
+ render_dist = (
+ colormap(render_dist.cpu().numpy()[0])
+ .permute((1, 2, 0))
+ .numpy()
+ .astype(np.uint8)
+ )
+ imageio.imwrite(
+ f"{self.render_dir}/val_{i:04d}_distortions_{step}.png", render_dist
+ )
+
+ pixels = pixels.permute(0, 3, 1, 2) # [1, 3, H, W]
+ colors = colors.permute(0, 3, 1, 2) # [1, 3, H, W]
+ metrics["psnr"].append(self.psnr(colors, pixels))
+ metrics["ssim"].append(self.ssim(colors, pixels))
+ metrics["lpips"].append(self.lpips(colors, pixels))
+
+ ellipse_time /= len(valloader)
+
+ psnr = torch.stack(metrics["psnr"]).mean()
+ ssim = torch.stack(metrics["ssim"]).mean()
+ lpips = torch.stack(metrics["lpips"]).mean()
+ print(
+ f"PSNR: {psnr.item():.3f}, SSIM: {ssim.item():.4f}, LPIPS: {lpips.item():.3f} "
+ f"Time: {ellipse_time:.3f}s/image "
+ f"Number of GS: {len(self.splats['means'])}"
+ )
+ # save stats as json
+ stats = {
+ "psnr": psnr.item(),
+ "ssim": ssim.item(),
+ "lpips": lpips.item(),
+ "ellipse_time": ellipse_time,
+ "num_GS": len(self.splats["means"]),
+ }
+ with open(f"{self.stats_dir}/val_step{step:04d}.json", "w") as f:
+ json.dump(stats, f)
+ # save stats to tensorboard
+ for k, v in stats.items():
+ self.writer.add_scalar(f"val/{k}", v, step)
+ self.writer.flush()
+
+ @torch.no_grad()
+ def render_traj(self, step: int):
+ """Entry for trajectory rendering."""
+ print("Running trajectory rendering...")
+ cfg = self.cfg
+ device = self.device
+
+ camtoworlds = self.parser.camtoworlds[5:-5]
+ camtoworlds = generate_interpolated_path(camtoworlds, 1) # [N, 3, 4]
+ camtoworlds = np.concatenate(
+ [
+ camtoworlds,
+ np.repeat(np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds), axis=0),
+ ],
+ axis=1,
+ ) # [N, 4, 4]
+
+ camtoworlds = torch.from_numpy(camtoworlds).float().to(device)
+ K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device)
+ width, height = list(self.parser.imsize_dict.values())[0]
+
+ canvas_all = []
+ for i in tqdm.trange(len(camtoworlds), desc="Rendering trajectory"):
+ renders, _, _, surf_normals, _, _, _ = self.rasterize_splats(
+ camtoworlds=camtoworlds[i : i + 1],
+ Ks=K[None],
+ width=width,
+ height=height,
+ sh_degree=cfg.sh_degree,
+ near_plane=cfg.near_plane,
+ far_plane=cfg.far_plane,
+ render_mode="RGB+ED",
+ ) # [1, H, W, 4]
+ colors = torch.clamp(renders[0, ..., 0:3], 0.0, 1.0) # [H, W, 3]
+ depths = renders[0, ..., 3:4] # [H, W, 1]
+ depths = (depths - depths.min()) / (depths.max() - depths.min())
+
+ surf_normals = (surf_normals - surf_normals.min()) / (
+ surf_normals.max() - surf_normals.min()
+ )
+
+ # write images
+ canvas = torch.cat(
+                [colors, depths.repeat(1, 1, 3)], dim=0 if width > height else 1
+ )
+ canvas = (canvas.cpu().numpy() * 255).astype(np.uint8)
+ canvas_all.append(canvas)
+
+ # save to video
+ video_dir = f"{cfg.result_dir}/videos"
+ os.makedirs(video_dir, exist_ok=True)
+ writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30)
+ for canvas in canvas_all:
+ writer.append_data(canvas)
+ writer.close()
+ print(f"Video saved to {video_dir}/traj_{step}.mp4")
+
+ @torch.no_grad()
+ def _viewer_render_fn(
+ self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int]
+ ):
+ """Callable function for the viewer."""
+ W, H = img_wh
+ c2w = camera_state.c2w
+ K = camera_state.get_K(img_wh)
+ c2w = torch.from_numpy(c2w).float().to(self.device)
+ K = torch.from_numpy(K).float().to(self.device)
+
+ render_colors, _, _, _, _, _, _ = self.rasterize_splats(
+ camtoworlds=c2w[None],
+ Ks=K[None],
+ width=W,
+ height=H,
+            sh_degree=self.cfg.sh_degree,  # activate all SH degrees
+ radius_clip=3.0, # skip GSs that have small image radius (in pixels)
+ ) # [1, H, W, 3]
+ return render_colors[0].cpu().numpy()
+
+
+def main(cfg: Config):
+ runner = Runner(cfg)
+
+ if cfg.ckpt is not None:
+ # run eval only
+ ckpt = torch.load(cfg.ckpt, map_location=runner.device)
+ for k in runner.splats.keys():
+ runner.splats[k].data = ckpt["splats"][k]
+ runner.eval(step=ckpt["step"])
+ runner.render_traj(step=ckpt["step"])
+ else:
+ runner.train()
+
+ if not cfg.disable_viewer:
+ print("Viewer running... Ctrl+C to exit.")
+ time.sleep(1000000)
+
+
+if __name__ == "__main__":
+ cfg = tyro.cli(Config)
+ cfg.adjust_steps(cfg.steps_scaler)
+ main(cfg)
diff --git a/examples/utils.py b/examples/utils.py
index 750322e5f..b79f4244b 100644
--- a/examples/utils.py
+++ b/examples/utils.py
@@ -5,6 +5,8 @@
from sklearn.neighbors import NearestNeighbors
from torch import Tensor
import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from matplotlib import colormaps
class CameraOptModule(torch.nn.Module):
@@ -152,3 +154,71 @@ def set_random_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
+
+
+# ref: https://github.com/hbb1/2d-gaussian-splatting/blob/main/utils/general_utils.py#L163
+def colormap(img, cmap="jet"):
+ W, H = img.shape[:2]
+ dpi = 300
+ fig, ax = plt.subplots(1, figsize=(H / dpi, W / dpi), dpi=dpi)
+ im = ax.imshow(img, cmap=cmap)
+ ax.set_axis_off()
+ fig.colorbar(im, ax=ax)
+ fig.tight_layout()
+ fig.canvas.draw()
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
+ img = torch.from_numpy(data).float().permute(2, 0, 1)
+ plt.close()
+ return img
+
+
+def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
+ """Convert single channel to a color img.
+
+ Args:
+ img (torch.Tensor): (..., 1) float32 single channel image.
+ colormap (str): Colormap for img.
+
+ Returns:
+ (..., 3) colored img with colors in [0, 1].
+ """
+ img = torch.nan_to_num(img, 0)
+ if colormap == "gray":
+ return img.repeat(1, 1, 3)
+ img_long = (img * 255).long()
+ img_long_min = torch.min(img_long)
+ img_long_max = torch.max(img_long)
+ assert img_long_min >= 0, f"the min value is {img_long_min}"
+ assert img_long_max <= 255, f"the max value is {img_long_max}"
+ return torch.tensor(
+ colormaps[colormap].colors, # type: ignore
+ device=img.device,
+ )[img_long[..., 0]]
+
+
+def apply_depth_colormap(
+ depth: torch.Tensor,
+ acc: torch.Tensor = None,
+ near_plane: float = None,
+ far_plane: float = None,
+) -> torch.Tensor:
+ """Converts a depth image to color for easier analysis.
+
+ Args:
+ depth (torch.Tensor): (..., 1) float32 depth.
+ acc (torch.Tensor | None): (..., 1) optional accumulation mask.
+ near_plane: Closest depth to consider. If None, use min image value.
+ far_plane: Furthest depth to consider. If None, use max image value.
+
+ Returns:
+ (..., 3) colored depth image with colors in [0, 1].
+ """
+ near_plane = near_plane or float(torch.min(depth))
+ far_plane = far_plane or float(torch.max(depth))
+ depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
+ depth = torch.clip(depth, 0.0, 1.0)
+ img = apply_float_colormap(depth, colormap="turbo")
+ if acc is not None:
+ img = img * acc + (1.0 - acc)
+ return img
diff --git a/gsplat/__init__.py b/gsplat/__init__.py
index 8982d8a3c..df47d1555 100644
--- a/gsplat/__init__.py
+++ b/gsplat/__init__.py
@@ -2,6 +2,7 @@
from .compression import PngCompression
from .cuda._torch_impl import accumulate
+from .cuda._torch_impl_2dgs import accumulate_2dgs
from .cuda._wrapper import (
fully_fused_projection,
isect_offset_encode,
@@ -12,21 +13,26 @@
rasterize_to_pixels,
spherical_harmonics,
world_to_cam,
+ fully_fused_projection_2dgs,
+ rasterize_to_pixels_2dgs,
+ rasterize_to_indices_in_range_2dgs,
)
from .rendering import (
rasterization,
+ rasterization_2dgs,
rasterization_inria_wrapper,
+ rasterization_2dgs_inria_wrapper,
)
from .strategy import DefaultStrategy, MCMCStrategy, Strategy
from .version import __version__
-
all = [
"PngCompression",
"DefaultStrategy",
"MCMCStrategy",
"Strategy",
"rasterization",
+ "rasterization_2dgs",
"rasterization_inria_wrapper",
"spherical_harmonics",
"isect_offset_encode",
@@ -38,5 +44,9 @@
"world_to_cam",
"accumulate",
"rasterize_to_indices_in_range",
- "__version__",
+ "full_fused_projection_2dgs",
+ "rasterize_to_pixels_2dgs",
+ "rasterize_to_indices_in_range_2dgs",
+ "accumulate_2dgs",
+ "rasterization_2dgs_inria_wrapper" "__version__",
]
diff --git a/gsplat/cuda/_torch_impl.py b/gsplat/cuda/_torch_impl.py
index 2585e36b3..1b06e8a59 100644
--- a/gsplat/cuda/_torch_impl.py
+++ b/gsplat/cuda/_torch_impl.py
@@ -6,14 +6,8 @@
from torch import Tensor
-def _quat_scale_to_covar_preci(
- quats: Tensor, # [N, 4],
- scales: Tensor, # [N, 3],
- compute_covar: bool = True,
- compute_preci: bool = True,
- triu: bool = False,
-) -> Tuple[Optional[Tensor], Optional[Tensor]]:
- """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`."""
+def _quat_to_rotmat(quats: Tensor) -> Tensor:
+ """Convert quaternion to rotation matrix."""
quats = F.normalize(quats, p=2, dim=-1)
w, x, y, z = torch.unbind(quats, dim=-1)
R = torch.stack(
@@ -30,9 +24,28 @@ def _quat_scale_to_covar_preci(
],
dim=-1,
)
+ return R.reshape(quats.shape[:-1] + (3, 3))
+
+
+def _quat_scale_to_matrix(
+ quats: Tensor, # [N, 4],
+ scales: Tensor, # [N, 3],
+) -> Tensor:
+ """Convert quaternion and scale to a 3x3 matrix (R * S)."""
+ R = _quat_to_rotmat(quats) # (..., 3, 3)
+ M = R * scales[..., None, :] # (..., 3, 3)
+ return M
+
- R = R.reshape(quats.shape[:-1] + (3, 3)) # (..., 3, 3)
- # R.register_hook(lambda grad: print("grad R", grad))
+def _quat_scale_to_covar_preci(
+ quats: Tensor, # [N, 4],
+ scales: Tensor, # [N, 3],
+ compute_covar: bool = True,
+ compute_preci: bool = True,
+ triu: bool = False,
+) -> Tuple[Optional[Tensor], Optional[Tensor]]:
+ """PyTorch implementation of `gsplat.cuda._wrapper.quat_scale_to_covar_preci()`."""
+ R = _quat_to_rotmat(quats) # (..., 3, 3)
if compute_covar:
M = R * scales[..., None, :] # (..., 3, 3)
diff --git a/gsplat/cuda/_torch_impl_2dgs.py b/gsplat/cuda/_torch_impl_2dgs.py
new file mode 100644
index 000000000..11f85546e
--- /dev/null
+++ b/gsplat/cuda/_torch_impl_2dgs.py
@@ -0,0 +1,272 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+
+from gsplat.cuda._torch_impl import _quat_scale_to_matrix
+
+
+def _fully_fused_projection_2dgs(
+ means: Tensor, # [N, 3]
+ quats: Tensor, # [N, 4]
+ scales: Tensor, # [N, 3]
+ viewmats: Tensor, # [C, 4, 4]
+ Ks: Tensor, # [C, 3, 3]
+ width: int,
+ height: int,
+ near_plane: float = 0.01,
+ far_plane: float = 1e10,
+ eps: float = 1e-6,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """PyTorch implementation of `gsplat.cuda._wrapper.fully_fused_projection_2dgs()`
+
+ .. note::
+
+        This is a minimal implementation of the fully fused version, which has more
+ arguments. Not all arguments are supported.
+ """
+ R_cw = viewmats[:, :3, :3] # [C, 3, 3]
+ t_cw = viewmats[:, :3, 3] # [C, 3]
+ means_c = torch.einsum("cij,nj->cni", R_cw, means) + t_cw[:, None, :] # (C, N, 3)
+ RS_wl = _quat_scale_to_matrix(quats, scales)
+ RS_cl = torch.einsum("cij,njk->cnik", R_cw, RS_wl) # [C, N, 3, 3]
+
+ # compute normals
+ normals = RS_cl[..., 2] # [C, N, 3]
+ C, N, _ = normals.shape
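+    # orient normals toward the camera: cos > 0 iff the splat normal faces the
+    # camera side, so flip the sign of any normal pointing away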
+ cos = -normals.reshape((C * N, 1, 3)) @ means_c.reshape((C * N, 3, 1))
+ cos = cos.reshape((C, N, 1))
+ multiplier = torch.where(cos > 0, torch.tensor(1.0), torch.tensor(-1.0))
+ normals *= multiplier
+
+ # ray transform matrix, omitting the z rotation
+ T_cl = torch.cat([RS_cl[..., :2], means_c[..., None]], dim=-1) # [C, N, 3, 3]
+ T_sl = torch.einsum("cij,cnjk->cnik", Ks[:, :3, :3], T_cl) # [C, N, 3, 3]
+ # in paper notation M = (WH)^T
+ # later h_u = M @ h_x, h_v = M @ h_y
+ M = torch.transpose(T_sl, -1, -2) # [C, N, 3, 3]
+
+    # compute the AABB of the projected gaussian
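+    # Closed-form screen-space center and extent of the projected splat: with
+    # f = (1, 1, -1) / d, the center is sum_i f_i * M[i, :2] * M[i, 2] and the
+    # per-axis extent is sqrt(center**2 - sum_i f_i * M[i, :2]**2); the radius
+    # below is then taken as 3 sigma around the center.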
+ test = torch.tensor([1.0, 1.0, -1.0], device=means.device).reshape(1, 1, 3)
+ d = (M[..., 2] * M[..., 2] * test).sum(dim=-1, keepdim=True) # [C, N, 1]
+ valid = torch.abs(d) > eps
+ f = torch.where(valid, test / d, torch.zeros_like(test)).unsqueeze(
+ -1
+ ) # (C, N, 3, 1)
+ means2d = (M[..., :2] * M[..., 2:3] * f).sum(dim=-2) # [C, N, 2]
+ extents = torch.sqrt(
+ means2d**2 - (M[..., :2] * M[..., :2] * f).sum(dim=-2)
+ ) # [C, N, 2]
+
+ depths = means_c[..., 2] # [C, N]
+ radius = torch.ceil(3.0 * torch.max(extents, dim=-1).values) # (C, N)
+
+ valid = valid.squeeze(-1) & (depths > near_plane) & (depths < far_plane)
+ radius[~valid] = 0.0
+
+ inside = (
+ (means2d[..., 0] + radius > 0)
+ & (means2d[..., 0] - radius < width)
+ & (means2d[..., 1] + radius > 0)
+ & (means2d[..., 1] - radius < height)
+ )
+ radius[~inside] = 0.0
+ radii = radius.int()
+ return radii, means2d, depths, M, normals
+
+
+def accumulate_2dgs(
+ means2d: Tensor, # [C, N, 2]
+ ray_transforms: Tensor, # [C, N, 3, 3]
+ opacities: Tensor, # [C, N]
+ colors: Tensor, # [C, N, channels]
+ normals: Tensor, # [C, N, 3]
+ gaussian_ids: Tensor, # [M]
+ pixel_ids: Tensor, # [M]
+ camera_ids: Tensor, # [M]
+ image_width: int,
+ image_height: int,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ """Alpha compositing for 2DGS.
+
+ .. warning::
+        This function requires the `nerfacc` package to be installed. Please install it using the following command: `pip install nerfacc`.
+
+ Args:
+ means2d: Gaussian means in 2D. [C, N, 2]
+        ray_transforms: Transformation matrices that transform rays in pixel space into the splat's local frame. [C, N, 3, 3]
+        opacities: Per-view Gaussian opacities (for example, when antialiasing is enabled, the Gaussian in
+            each view would effectively have a different opacity). [C, N]
+ colors: Per-view Gaussian colors. Supports N-D features. [C, N, channels]
+ normals: Per-view Gaussian normals. [C, N, 3]
+ gaussian_ids: Collection of Gaussian indices to be rasterized. A flattened list of shape [M].
+ pixel_ids: Collection of pixel indices (row-major) to be rasterized. A flattened list of shape [M].
+ camera_ids: Collection of camera indices to be rasterized. A flattened list of shape [M].
+ image_width: Image width.
+ image_height: Image height.
+
+ Returns:
+ A tuple:
+
+ **renders**: Accumulated colors. [C, image_height, image_width, channels]
+
+ **alphas**: Accumulated opacities. [C, image_height, image_width, 1]
+
+        **normals**: Accumulated normals. [C, image_height, image_width, 3]
+ """
+
+ try:
+ from nerfacc import accumulate_along_rays, render_weight_from_alpha
+ except ImportError:
+ raise ImportError("Please install nerfacc package: pip install nerfacc")
+
+ C, N = means2d.shape[:2]
+ channels = colors.shape[-1]
+
+ pixel_ids_x = pixel_ids % image_width + 0.5
+ pixel_ids_y = pixel_ids // image_width + 0.5
+ pixel_coords = torch.stack([pixel_ids_x, pixel_ids_y], dim=-1) # [M, 2]
+ deltas = pixel_coords - means2d[camera_ids, gaussian_ids] # [M, 2]
+
+ M = ray_transforms[camera_ids, gaussian_ids] # [M, 3, 3]
+
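+    # Ray-splat intersection: map each pixel's x-plane and y-plane into the
+    # splat's tangent frame (rows of M), then intersect the two planes via a
+    # cross product to get the homogeneous splat-local hit point (u, v, w).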
+ h_u = -M[..., 0, :3] + M[..., 2, :3] * pixel_ids_x[..., None] # [M, 3]
+ h_v = -M[..., 1, :3] + M[..., 2, :3] * pixel_ids_y[..., None] # [M, 3]
+ tmp = torch.cross(h_u, h_v, dim=-1)
+ us = tmp[..., 0] / tmp[..., 2]
+ vs = tmp[..., 1] / tmp[..., 2]
+ sigmas_3d = us**2 + vs**2 # [M]
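+    # object-space low-pass filter from the 2DGS paper: bound the splat-space
+    # falloff by a screen-space Gaussian so nearly edge-on splats still cover
+    # roughly a pixel footprint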
+ sigmas_2d = 2 * (deltas[..., 0] ** 2 + deltas[..., 1] ** 2)
+ sigmas = 0.5 * torch.minimum(sigmas_3d, sigmas_2d) # [M]
+
+ alphas = torch.clamp_max(
+ opacities[camera_ids, gaussian_ids] * torch.exp(-sigmas), 0.999
+ )
+
+ indices = camera_ids * image_height * image_width + pixel_ids
+ total_pixels = C * image_height * image_width
+
+ weights, trans = render_weight_from_alpha(
+ alphas, ray_indices=indices, n_rays=total_pixels
+ )
+ renders = accumulate_along_rays(
+ weights,
+ colors[camera_ids, gaussian_ids],
+ ray_indices=indices,
+ n_rays=total_pixels,
+ ).reshape(C, image_height, image_width, channels)
+ alphas = accumulate_along_rays(
+ weights, None, ray_indices=indices, n_rays=total_pixels
+ ).reshape(C, image_height, image_width, 1)
+ renders_normal = accumulate_along_rays(
+ weights,
+ normals[camera_ids, gaussian_ids],
+ ray_indices=indices,
+ n_rays=total_pixels,
+ ).reshape(C, image_height, image_width, 3)
+
+ return renders, alphas, renders_normal
+
+
+def _rasterize_to_pixels_2dgs(
+ means2d: Tensor, # [C, N, 2]
+ ray_transforms: Tensor, # [C, N, 3, 3]
+ colors: Tensor, # [C, N, channels]
+ normals: Tensor, # [C, N, 3]
+ opacities: Tensor, # [C, N]
+ image_width: int,
+ image_height: int,
+ tile_size: int,
+ isect_offsets: Tensor, # [C, tile_height, tile_width]
+ flatten_ids: Tensor, # [n_isects]
+ backgrounds: Optional[Tensor] = None, # [C, channels]
+ batch_per_iter: int = 100,
+):
+ """Pytorch implementation of `gsplat.cuda._wrapper.rasterize_to_pixels_2dgs()`.
+
+ This function rasterizes 2D Gaussians to pixels in a Pytorch-friendly way. It
+ iteratively accumulates the renderings within each batch of Gaussians. The
+    iterations are controlled by `batch_per_iter`.
+
+ .. note::
+ This is a minimal implementation of the fully fused version, which has more
+ arguments. Not all arguments are supported.
+
+ .. note::
+
+ This function relies on Pytorch's autograd for the backpropagation. It is much slower
+    than our fully fused rasterization implementation and consumes much more GPU memory.
+ But it could serve as a playground for new ideas or debugging, as no backward
+ implementation is needed.
+
+ .. warning::
+
+ This function requires the `nerfacc` package to be installed. Please install it
+ using the following command `pip install nerfacc`.
+ """
+ from ._wrapper import rasterize_to_indices_in_range_2dgs
+
+ C, N = means2d.shape[:2]
+ n_isects = len(flatten_ids)
+ device = means2d.device
+
+ render_colors = torch.zeros(
+ (C, image_height, image_width, colors.shape[-1]), device=device
+ )
+ render_alphas = torch.zeros((C, image_height, image_width, 1), device=device)
+ render_normals = torch.zeros((C, image_height, image_width, 3), device=device)
+
+ # Split Gaussians into batches and iteratively accumulate the renderings
+ block_size = tile_size * tile_size
+ isect_offsets_fl = torch.cat(
+ [isect_offsets.flatten(), torch.tensor([n_isects], device=device)]
+ )
+ max_range = (isect_offsets_fl[1:] - isect_offsets_fl[:-1]).max().item()
+ num_batches = (max_range + block_size - 1) // block_size
+ for step in range(0, num_batches, batch_per_iter):
+ transmittances = 1.0 - render_alphas[..., 0]
+
+ # Find the M intersections between pixels and gaussians.
+ # Each intersection corresponds to a tuple (gs_id, pixel_id, camera_id)
+ gs_ids, pixel_ids, camera_ids = rasterize_to_indices_in_range_2dgs(
+ step,
+ step + batch_per_iter,
+ transmittances,
+ means2d,
+ ray_transforms,
+ opacities,
+ image_width,
+ image_height,
+ tile_size,
+ isect_offsets,
+ flatten_ids,
+ ) # [M], [M]
+ if len(gs_ids) == 0:
+ break
+
+ # Accumulate the renderings within this batch of Gaussians.
+ renders_step, accs_step, renders_normal_step = accumulate_2dgs(
+ means2d,
+ ray_transforms,
+ opacities,
+ colors,
+ normals,
+ gs_ids,
+ pixel_ids,
+ camera_ids,
+ image_width,
+ image_height,
+ )
+ render_colors = render_colors + renders_step * transmittances[..., None]
+ render_alphas = render_alphas + accs_step * transmittances[..., None]
+ render_normals = (
+ render_normals + renders_normal_step * transmittances[..., None]
+ )
+
+ if backgrounds is not None:
+ render_colors = render_colors + backgrounds[:, None, None, :] * (
+ 1.0 - render_alphas
+ )
+
+ return render_colors, render_alphas, render_normals
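+
+
+def _example_2dgs_torch_pipeline():
+    """A minimal, hypothetical sketch (not part of the public API) of how the
+    torch reference path above composes with the existing gsplat helpers
+    `fully_fused_projection_2dgs`, `isect_tiles`, and `isect_offset_encode`."""
+    import math
+
+    from ._wrapper import (
+        fully_fused_projection_2dgs,
+        isect_offset_encode,
+        isect_tiles,
+    )
+
+    device = "cuda"
+    C, N, width, height, tile_size = 1, 1024, 640, 480, 16
+    means = torch.randn(N, 3, device=device)
+    quats = torch.randn(N, 4, device=device)  # need not be normalized
+    scales = torch.rand(N, 3, device=device)  # only the first two scales are used
+    opacities = torch.rand(C, N, device=device)
+    colors = torch.rand(C, N, 3, device=device)
+    viewmats = torch.eye(4, device=device)[None]  # [C, 4, 4], world-to-camera
+    Ks = torch.tensor(
+        [[[300.0, 0.0, width / 2], [0.0, 300.0, height / 2], [0.0, 0.0, 1.0]]],
+        device=device,
+    )
+
+    # project splats, bin them into tiles, then run the torch rasterizer
+    radii, means2d, depths, ray_transforms, normals = fully_fused_projection_2dgs(
+        means, quats, scales, viewmats, Ks, width, height
+    )
+    tile_width = math.ceil(width / tile_size)
+    tile_height = math.ceil(height / tile_size)
+    _, isect_ids, flatten_ids = isect_tiles(
+        means2d, radii, depths, tile_size, tile_width, tile_height
+    )
+    isect_offsets = isect_offset_encode(isect_ids, C, tile_width, tile_height)
+    return _rasterize_to_pixels_2dgs(
+        means2d,
+        ray_transforms,
+        colors,
+        normals,
+        opacities,
+        width,
+        height,
+        tile_size,
+        isect_offsets,
+        flatten_ids,
+    )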
diff --git a/gsplat/cuda/_wrapper.py b/gsplat/cuda/_wrapper.py
index ded7d5989..79dcfd29f 100644
--- a/gsplat/cuda/_wrapper.py
+++ b/gsplat/cuda/_wrapper.py
@@ -1207,3 +1207,743 @@ def backward(ctx, v_colors: Tensor):
if not compute_v_dirs:
v_dirs = None
return None, v_dirs, v_coeffs, None
+
+
+###### 2DGS ######
+def fully_fused_projection_2dgs(
+ means: Tensor, # [N, 3]
+ quats: Tensor, # [N, 4]
+ scales: Tensor,
+ viewmats: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ eps2d: float = 0.3,
+ near_plane: float = 0.01,
+ far_plane: float = 1e10,
+ radius_clip: float = 0.0,
+ packed: bool = False,
+ sparse_grad: bool = False,
+) -> Tuple[Tensor, ...]:
+ """Prepare Gaussians for rasterization
+
+ This function prepares ray-splat intersection matrices, computes
+ per splat bounding box and 2D means in image space.
+
+ Args:
+        means: Gaussian means. [N, 3]
+        quats: Quaternions (no need to be normalized). [N, 4]
+        scales: Scales. [N, 3]
+        viewmats: World-to-camera matrices. [C, 4, 4]
+        Ks: Camera intrinsics. [C, 3, 3]
+        width: Image width.
+        height: Image height.
+        eps2d: Screen-space blur epsilon, accepted for API consistency with the
+          3DGS projection (the 2DGS kernels currently do not use it). Default: 0.3.
+        near_plane: Near plane distance. Default: 0.01.
+        far_plane: Far plane distance. Default: 1e10.
+        radius_clip: Gaussians with projected radii smaller than this value will be ignored. Default: 0.0.
+        packed: If True, the output tensors will be packed into a flattened tensor. Default: False.
+        sparse_grad (Experimental): This is only effective when `packed` is True. If True, during backward the gradients
+          of {`means`, `quats`, `scales`} will be a sparse Tensor in COO layout. Default: False.
+
+ Returns:
+ A tuple:
+
+ If `packed` is True:
+
+ - **camera_ids**. The row indices of the projected Gaussians. Int32 tensor of shape [nnz].
+ - **gaussian_ids**. The column indices of the projected Gaussians. Int32 tensor of shape [nnz].
+        - **radii**. The maximum radius of the projected Gaussians in pixel units. Int32 tensor of shape [nnz].
+        - **means**. Projected Gaussian means in 2D. [nnz, 2]
+        - **depths**. The z-depth of the projected Gaussians. [nnz]
+        - **ray_transforms**. Transformation matrices that transform xy-planes in pixel space into splat coordinates, i.e. :math:`(WH)^{T}` in equation (9) of the 2DGS paper. [nnz, 3, 3]
+        - **normals**. The normals in camera space. [nnz, 3]
+
+ If `packed` is False:
+
+        - **radii**. The maximum radius of the projected Gaussians in pixel units. Int32 tensor of shape [C, N].
+        - **means**. Projected Gaussian means in 2D. [C, N, 2]
+        - **depths**. The z-depth of the projected Gaussians. [C, N]
+        - **ray_transforms**. Transformation matrices that transform xy-planes in pixel space into splat coordinates. [C, N, 3, 3]
+        - **normals**. The normals in camera space. [C, N, 3]
+
+ """
+ C = viewmats.size(0)
+ N = means.size(0)
+ assert means.size() == (N, 3), means.size()
+ assert viewmats.size() == (C, 4, 4), viewmats.size()
+ assert Ks.size() == (C, 3, 3), Ks.size()
+ means = means.contiguous()
+ assert quats is not None, "quats is required"
+ assert scales is not None, "scales is required"
+ assert quats.size() == (N, 4), quats.size()
+ assert scales.size() == (N, 3), scales.size()
+ quats = quats.contiguous()
+ scales = scales.contiguous()
+ if sparse_grad:
+ assert packed, "sparse_grad is only supported when packed is True"
+
+ viewmats = viewmats.contiguous()
+ Ks = Ks.contiguous()
+ if packed:
+ return _FullyFusedProjectionPacked2DGS.apply(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ near_plane,
+ far_plane,
+ radius_clip,
+ sparse_grad,
+ )
+ else:
+ return _FullyFusedProjection2DGS.apply(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ eps2d,
+ near_plane,
+ far_plane,
+ radius_clip,
+ )
+
+
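+# A hypothetical call sketch for the packed path (shapes follow the docstring):
+#
+#   camera_ids, gaussian_ids, radii, means2d, depths, ray_transforms, normals = (
+#       fully_fused_projection_2dgs(
+#           means, quats, scales, viewmats, Ks, width, height, packed=True
+#       )
+#   )
+#   # nnz = camera_ids.numel() entries, one per visible camera-splat pair
+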
+class _FullyFusedProjection2DGS(torch.autograd.Function):
+ """Projects Gaussians to 2D."""
+
+ @staticmethod
+ def forward(
+ ctx,
+ means: Tensor,
+ quats: Tensor,
+ scales: Tensor,
+ viewmats: Tensor,
+ Ks: Tensor,
+ width: int,
+ height: int,
+ eps2d: float,
+ near_plane: float,
+ far_plane: float,
+ radius_clip: float,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ radii, means2d, depths, ray_transforms, normals = _make_lazy_cuda_func(
+ "fully_fused_projection_fwd_2dgs"
+ )(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ eps2d,
+ near_plane,
+ far_plane,
+ radius_clip,
+ )
+ ctx.save_for_backward(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ radii,
+ ray_transforms,
+ normals,
+ )
+ ctx.width = width
+ ctx.height = height
+ ctx.eps2d = eps2d
+
+ return radii, means2d, depths, ray_transforms, normals
+
+ @staticmethod
+ def backward(ctx, v_radii, v_means2d, v_depths, v_ray_transforms, v_normals):
+ (
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ radii,
+ ray_transforms,
+ normals,
+ ) = ctx.saved_tensors
+ width = ctx.width
+ height = ctx.height
+ eps2d = ctx.eps2d
+ v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
+ "fully_fused_projection_bwd_2dgs"
+ )(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ radii,
+ ray_transforms,
+ v_means2d.contiguous(),
+ v_depths.contiguous(),
+ v_normals.contiguous(),
+ v_ray_transforms.contiguous(),
+ ctx.needs_input_grad[3], # viewmats_requires_grad
+ )
+ if not ctx.needs_input_grad[0]:
+ v_means = None
+ if not ctx.needs_input_grad[1]:
+ v_quats = None
+ if not ctx.needs_input_grad[2]:
+ v_scales = None
+ if not ctx.needs_input_grad[3]:
+ v_viewmats = None
+
+ return (
+ v_means,
+ v_quats,
+ v_scales,
+ v_viewmats,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
+
+class _FullyFusedProjectionPacked2DGS(torch.autograd.Function):
+ """Projects Gaussians to 2D. Return packed tensors."""
+
+ @staticmethod
+ def forward(
+ ctx,
+ means: Tensor, # [N, 3]
+ quats: Tensor, # [N, 4]
+ scales: Tensor, # [N, 3]
+ viewmats: Tensor, # [C, 4, 4]
+ Ks: Tensor, # [C, 3, 3]
+ width: int,
+ height: int,
+ near_plane: float,
+ far_plane: float,
+ radius_clip: float,
+ sparse_grad: bool,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
+ (
+ indptr,
+ camera_ids,
+ gaussian_ids,
+ radii,
+ means2d,
+ depths,
+ ray_transforms,
+ normals,
+ ) = _make_lazy_cuda_func("fully_fused_projection_packed_fwd_2dgs")(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ near_plane,
+ far_plane,
+ radius_clip,
+ )
+ ctx.save_for_backward(
+ camera_ids,
+ gaussian_ids,
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ ray_transforms,
+ )
+ ctx.width = width
+ ctx.height = height
+ ctx.sparse_grad = sparse_grad
+
+ return camera_ids, gaussian_ids, radii, means2d, depths, ray_transforms, normals
+
+ @staticmethod
+ def backward(
+ ctx,
+ v_camera_ids,
+ v_gaussian_ids,
+ v_radii,
+ v_means2d,
+ v_depths,
+ v_ray_transforms,
+ v_normals,
+ ):
+ (
+ camera_ids,
+ gaussian_ids,
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ ray_transforms,
+ ) = ctx.saved_tensors
+ width = ctx.width
+ height = ctx.height
+ sparse_grad = ctx.sparse_grad
+
+ v_means, v_quats, v_scales, v_viewmats = _make_lazy_cuda_func(
+ "fully_fused_projection_packed_bwd_2dgs"
+ )(
+ means,
+ quats,
+ scales,
+ viewmats,
+ Ks,
+ width,
+ height,
+ camera_ids,
+ gaussian_ids,
+ ray_transforms,
+ v_means2d.contiguous(),
+ v_depths.contiguous(),
+ v_ray_transforms.contiguous(),
+ v_normals.contiguous(),
+ ctx.needs_input_grad[4], # viewmats_requires_grad
+ sparse_grad,
+ )
+
+ if not ctx.needs_input_grad[0]:
+ v_means = None
+ else:
+ if sparse_grad:
+ # TODO: gaussian_ids is duplicated so not ideal.
+ # An idea is to directly set the attribute (e.g., .sparse_grad) of
+ # the tensor but this requires the tensor to be leaf node only. And
+ # a customized optimizer would be needed in this case.
+ v_means = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=v_means, # [nnz, 3]
+ size=means.size(), # [N, 3]
+ is_coalesced=len(viewmats) == 1,
+ )
+ if not ctx.needs_input_grad[1]:
+ v_quats = None
+ else:
+ if sparse_grad:
+ v_quats = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=v_quats, # [nnz, 4]
+ size=quats.size(), # [N, 4]
+ is_coalesced=len(viewmats) == 1,
+ )
+ if not ctx.needs_input_grad[2]:
+ v_scales = None
+ else:
+ if sparse_grad:
+ v_scales = torch.sparse_coo_tensor(
+ indices=gaussian_ids[None], # [1, nnz]
+ values=v_scales, # [nnz, 3]
+ size=scales.size(), # [N, 3]
+ is_coalesced=len(viewmats) == 1,
+ )
+ if not ctx.needs_input_grad[4]:
+ v_viewmats = None
+
+ return (
+ v_means,
+ v_quats,
+ v_scales,
+ v_viewmats,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+
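+# With `packed=True` and `sparse_grad=True`, gradients for `means`, `quats`, and
+# `scales` come back as COO sparse tensors indexed by `gaussian_ids` (sketch;
+# duplicate indices are summed when coalescing):
+#
+#   v_means_dense = v_means.coalesce().to_dense()  # [N, 3]
+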
+
+def rasterize_to_pixels_2dgs(
+ means2d: Tensor,
+ ray_transforms: Tensor,
+ colors: Tensor,
+ opacities: Tensor,
+ normals: Tensor,
+ densify: Tensor,
+ image_width: int,
+ image_height: int,
+ tile_size: int,
+ isect_offsets: Tensor,
+ flatten_ids: Tensor,
+ backgrounds: Optional[Tensor] = None,
+ masks: Optional[Tensor] = None,
+ packed: bool = False,
+ absgrad: bool = False,
+ distloss: bool = False,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ """Rasterize Gaussians to pixels.
+
+ Args:
+ means2d: Projected Gaussian means. [C, N, 2] if packed is False, [nnz, 2] if packed is True.
+        ray_transforms: Transformation matrices that transform xy-planes in pixel space into splat coordinates. [C, N, 3, 3] if packed is False, [nnz, 3, 3] if packed is True.
+ colors: Gaussian colors or ND features. [C, N, channels] if packed is False, [nnz, channels] if packed is True.
+ opacities: Gaussian opacities that support per-view values. [C, N] if packed is False, [nnz] if packed is True.
+ normals: The normals in camera space. [C, N, 3] if packed is False, [nnz, 3] if packed is True.
+        densify: Dummy tensor used to keep track of gradients for densification. [C, N, 2] if packed is False, [nnz, 2] if packed is True.
+ tile_size: Tile size.
+ isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width]
+ flatten_ids: The global flatten indices in [C * N] or [nnz] from `isect_tiles()`. [n_isects]
+ backgrounds: Background colors. [C, channels]. Default: None.
+ masks: Optional tile mask to skip rendering GS to masked tiles. [C, tile_height, tile_width]. Default: None.
+ packed: If True, the input tensors are expected to be packed with shape [nnz, ...]. Default: False.
+        absgrad: If True, the backward pass will compute a `.absgrad` attribute for `means2d`. Default: False.
+        distloss: If True, enable the distortion loss computation. Default: False.
+
+ Returns:
+ A tuple:
+
+ - **Rendered colors**. [C, image_height, image_width, channels]
+ - **Rendered alphas**. [C, image_height, image_width, 1]
+ - **Rendered normals**. [C, image_height, image_width, 3]
+ - **Rendered distortion**. [C, image_height, image_width, 1]
+        - **Rendered median depth**. [C, image_height, image_width, 1]
+
+
+ """
+ C = isect_offsets.size(0)
+ device = means2d.device
+ if packed:
+ nnz = means2d.size(0)
+ assert means2d.shape == (nnz, 2), means2d.shape
+ assert ray_transforms.shape == (nnz, 3, 3), ray_transforms.shape
+ assert colors.shape[0] == nnz, colors.shape
+ assert opacities.shape == (nnz,), opacities.shape
+ else:
+ N = means2d.size(1)
+ assert means2d.shape == (C, N, 2), means2d.shape
+ assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape
+ assert colors.shape[:2] == (C, N), colors.shape
+ assert opacities.shape == (C, N), opacities.shape
+ if backgrounds is not None:
+ assert backgrounds.shape == (C, colors.shape[-1]), backgrounds.shape
+ backgrounds = backgrounds.contiguous()
+
+ # Pad the channels to the nearest supported number if necessary
+ channels = colors.shape[-1]
+ if channels > 512 or channels == 0:
+ # TODO: maybe worth to support zero channels?
+ raise ValueError(f"Unsupported number of color channels: {channels}")
+ if channels not in (1, 2, 3, 4, 8, 16, 32, 64, 128, 256, 512):
+ padded_channels = (1 << (channels - 1).bit_length()) - channels
+ colors = torch.cat(
+ [colors, torch.empty(*colors.shape[:-1], padded_channels, device=device)],
+ dim=-1,
+ )
+ if backgrounds is not None:
+ backgrounds = torch.cat(
+ [
+ backgrounds,
+ torch.empty(
+ *backgrounds.shape[:-1], padded_channels, device=device
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ padded_channels = 0
+ tile_height, tile_width = isect_offsets.shape[1:3]
+ assert (
+ tile_height * tile_size >= image_height
+ ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
+ assert (
+ tile_width * tile_size >= image_width
+ ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"
+
+ (
+ render_colors,
+ render_alphas,
+ render_normals,
+ render_distort,
+ render_median,
+ ) = _RasterizeToPixels2DGS.apply(
+ means2d.contiguous(),
+ ray_transforms.contiguous(),
+ colors.contiguous(),
+ opacities.contiguous(),
+ normals.contiguous(),
+ densify.contiguous(),
+ backgrounds,
+ masks,
+ image_width,
+ image_height,
+ tile_size,
+ isect_offsets.contiguous(),
+ flatten_ids.contiguous(),
+ absgrad,
+ distloss,
+ )
+
+ if padded_channels > 0:
+ render_colors = render_colors[..., :-padded_channels]
+
+ return render_colors, render_alphas, render_normals, render_distort, render_median
+
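+# Channel-padding sketch: the kernel supports a fixed set of channel counts, so
+# e.g. 5-channel colors are padded to 8 (the next power of two) before the CUDA
+# call and cropped afterwards:
+#
+#   channels = 5
+#   extra = (1 << (channels - 1).bit_length()) - channels  # -> 3
+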
+
+@torch.no_grad()
+def rasterize_to_indices_in_range_2dgs(
+ range_start: int,
+ range_end: int,
+ transmittances: Tensor,
+ means2d: Tensor,
+ ray_transforms: Tensor,
+ opacities: Tensor,
+ image_width: int,
+ image_height: int,
+ tile_size: int,
+ isect_offsets: Tensor,
+ flatten_ids: Tensor,
+) -> Tuple[Tensor, Tensor, Tensor]:
+ """Rasterizes a batch of Gaussians to images but only returns the indices.
+
+ .. note::
+
+ This function supports iterative rasterization, in which each call of this function
+ will rasterize a batch of Gaussians from near to far, defined by `[range_start, range_end)`.
+ If a one-step full rasterization is desired, set `range_start` to 0 and `range_end` to a really
+        large number, e.g., 1e10.
+
+ Args:
+ range_start: The start batch of Gaussians to be rasterized (inclusive).
+ range_end: The end batch of Gaussians to be rasterized (exclusive).
+        transmittances: Current transmittances. [C, image_height, image_width]
+ means2d: Projected Gaussian means. [C, N, 2]
+        ray_transforms: Transformation matrices that transform xy-planes in pixel space into splat coordinates. [C, N, 3, 3]
+ opacities: Gaussian opacities that support per-view values. [C, N]
+ image_width: Image width.
+ image_height: Image height.
+ tile_size: Tile size.
+ isect_offsets: Intersection offsets outputs from `isect_offset_encode()`. [C, tile_height, tile_width]
+ flatten_ids: The global flatten indices in [C * N] from `isect_tiles()`. [n_isects]
+
+ Returns:
+ A tuple:
+
+ - **Gaussian ids**. Gaussian ids for the pixel intersection. A flattened list of shape [M].
+ - **Pixel ids**. pixel indices (row-major). A flattened list of shape [M].
+ - **Camera ids**. Camera indices. A flattened list of shape [M].
+ """
+
+ C, N, _ = means2d.shape
+ assert ray_transforms.shape == (C, N, 3, 3), ray_transforms.shape
+ assert opacities.shape == (C, N), opacities.shape
+ assert isect_offsets.shape[0] == C, isect_offsets.shape
+
+ tile_height, tile_width = isect_offsets.shape[1:3]
+ assert (
+ tile_height * tile_size >= image_height
+ ), f"Assert Failed: {tile_height} * {tile_size} >= {image_height}"
+ assert (
+ tile_width * tile_size >= image_width
+ ), f"Assert Failed: {tile_width} * {tile_size} >= {image_width}"
+
+ out_gauss_ids, out_indices = _make_lazy_cuda_func(
+ "rasterize_to_indices_in_range_2dgs"
+ )(
+ range_start,
+ range_end,
+ transmittances.contiguous(),
+ means2d.contiguous(),
+ ray_transforms.contiguous(),
+ opacities.contiguous(),
+ image_width,
+ image_height,
+ tile_size,
+ isect_offsets.contiguous(),
+ flatten_ids.contiguous(),
+ )
+ out_pixel_ids = out_indices % (image_width * image_height)
+ out_camera_ids = out_indices // (image_width * image_height)
+ return out_gauss_ids, out_pixel_ids, out_camera_ids
+
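+# Decoding the flat indices (sketch): for the m-th returned intersection,
+#
+#   cam = camera_ids[m]
+#   pix = pixel_ids[m]                       # row-major within the image
+#   y, x = pix // image_width, pix % image_width
+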
+
+class _RasterizeToPixels2DGS(torch.autograd.Function):
+ """Rasterize gaussians 2DGS"""
+
+ @staticmethod
+ def forward(
+ ctx,
+ means2d: Tensor,
+ ray_transforms: Tensor,
+ colors: Tensor,
+ opacities: Tensor,
+ normals: Tensor,
+ densify: Tensor,
+ backgrounds: Tensor,
+ masks: Tensor,
+ width: int,
+ height: int,
+ tile_size: int,
+ isect_offsets: Tensor,
+ flatten_ids: Tensor,
+ absgrad: bool,
+ distloss: bool,
+    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+ (
+ render_colors,
+ render_alphas,
+ render_normals,
+ render_distort,
+ render_median,
+ last_ids,
+ median_ids,
+ ) = _make_lazy_cuda_func("rasterize_to_pixels_fwd_2dgs")(
+ means2d,
+ ray_transforms,
+ colors,
+ opacities,
+ normals,
+ backgrounds,
+ masks,
+ width,
+ height,
+ tile_size,
+ isect_offsets,
+ flatten_ids,
+ )
+
+ ctx.save_for_backward(
+ means2d,
+ ray_transforms,
+ colors,
+ opacities,
+ normals,
+ densify,
+ backgrounds,
+ masks,
+ isect_offsets,
+ flatten_ids,
+ render_colors,
+ render_alphas,
+ last_ids,
+ median_ids,
+ )
+ ctx.width = width
+ ctx.height = height
+ ctx.tile_size = tile_size
+ ctx.absgrad = absgrad
+ ctx.distloss = distloss
+
+        # cast double to float
+ render_alphas = render_alphas.float()
+ return (
+ render_colors,
+ render_alphas,
+ render_normals,
+ render_distort,
+ render_median,
+ )
+
+ @staticmethod
+ def backward(
+ ctx,
+ v_render_colors: Tensor,
+ v_render_alphas: Tensor,
+ v_render_normals: Tensor,
+ v_render_distort: Tensor,
+ v_render_median: Tensor,
+ ):
+
+ (
+ means2d,
+ ray_transforms,
+ colors,
+ opacities,
+ normals,
+ densify,
+ backgrounds,
+ masks,
+ isect_offsets,
+ flatten_ids,
+ render_colors,
+ render_alphas,
+ last_ids,
+ median_ids,
+ ) = ctx.saved_tensors
+ width = ctx.width
+ height = ctx.height
+ tile_size = ctx.tile_size
+ absgrad = ctx.absgrad
+
+ (
+ v_means2d_abs,
+ v_means2d,
+ v_ray_transforms,
+ v_colors,
+ v_opacities,
+ v_normals,
+ v_densify,
+ ) = _make_lazy_cuda_func("rasterize_to_pixels_bwd_2dgs")(
+ means2d,
+ ray_transforms,
+ colors,
+ opacities,
+ normals,
+ densify,
+ backgrounds,
+ masks,
+ width,
+ height,
+ tile_size,
+ isect_offsets,
+ flatten_ids,
+ render_colors,
+ render_alphas,
+ last_ids,
+ median_ids,
+ v_render_colors.contiguous(),
+ v_render_alphas.contiguous(),
+ v_render_normals.contiguous(),
+ v_render_distort.contiguous(),
+ v_render_median.contiguous(),
+ absgrad,
+ )
+ torch.cuda.synchronize()
+ if absgrad:
+ means2d.absgrad = v_means2d_abs
+
+ if ctx.needs_input_grad[6]:
+ v_backgrounds = (v_render_colors * (1.0 - render_alphas).float()).sum(
+ dim=(1, 2)
+ )
+ else:
+ v_backgrounds = None
+
+ return (
+ v_means2d,
+ v_ray_transforms,
+ v_colors,
+ v_opacities,
+ v_normals,
+ v_densify,
+ v_backgrounds,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
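+
+
+# When `absgrad=True`, the backward pass stashes the accumulated absolute
+# screen-space gradient on the input tensor (sketch; useful for densification
+# heuristics):
+#
+#   loss.backward()
+#   abs_grad = means2d.absgrad  # [C, N, 2]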
diff --git a/gsplat/cuda/csrc/CMakeLists.txt b/gsplat/cuda/csrc/CMakeLists.txt
index a315d9c91..0fbf46ba9 100644
--- a/gsplat/cuda/csrc/CMakeLists.txt
+++ b/gsplat/cuda/csrc/CMakeLists.txt
@@ -63,4 +63,4 @@ install(TARGETS gsplat
LIBRARY DESTINATION lib
)
-message(STATUS "CMake configuration done!")
+message(STATUS "CMake configuration done!")
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/bindings.h b/gsplat/cuda/csrc/bindings.h
index 004709a4d..9e2f8328c 100644
--- a/gsplat/cuda/csrc/bindings.h
+++ b/gsplat/cuda/csrc/bindings.h
@@ -309,6 +309,178 @@ std::tuple compute_relocation_tensor(
const int n_max
);
+//====== 2DGS ======//
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+fully_fused_projection_fwd_2dgs_tensor(
+ const torch::Tensor &means, // [N, 3]
+ const torch::Tensor &quats, // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const float eps2d,
+ const float near_plane,
+ const float far_plane,
+ const float radius_clip
+);
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+fully_fused_projection_bwd_2dgs_tensor(
+ // fwd inputs
+ const torch::Tensor &means, // [N, 3]
+ const torch::Tensor &quats, // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ // fwd outputs
+ const torch::Tensor &radii, // [C, N]
+ const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+ // grad outputs
+ const torch::Tensor &v_means2d, // [C, N, 2]
+ const torch::Tensor &v_depths, // [C, N]
+ const torch::Tensor &v_normals, // [C, N, 3]
+ const torch::Tensor &v_ray_transforms, // [C, N, 3, 3]
+ const bool viewmats_requires_grad
+);
+
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+rasterize_to_pixels_fwd_2dgs_tensor(
+ // Gaussian parameters
+ const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2]
+    const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3]
+ const torch::Tensor &colors, // [C, N, channels] or [nnz, channels]
+ const torch::Tensor &opacities, // [C, N] or [nnz]
+ const torch::Tensor &normals, // [C, N, 3] or [nnz, 3]
+    const at::optional<torch::Tensor> &backgrounds, // [C, channels]
+    const at::optional<torch::Tensor> &masks,       // [C, tile_height, tile_width]
+ // image size
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+ // intersections
+ const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+ const torch::Tensor &flatten_ids // [n_isects]
+);
+
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+rasterize_to_pixels_bwd_2dgs_tensor(
+ // Gaussian parameters
+ const torch::Tensor &means2d, // [C, N, 2] or [nnz, 2]
+ const torch::Tensor &ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3]
+ const torch::Tensor &colors, // [C, N, 3] or [nnz, 3]
+ const torch::Tensor &opacities, // [C, N] or [nnz]
+ const torch::Tensor &normals, // [C, N, 3] or [nnz, 3],
+ const torch::Tensor &densify,
+    const at::optional<torch::Tensor> &backgrounds, // [C, 3]
+    const at::optional<torch::Tensor> &masks,       // [C, tile_height, tile_width]
+ // image size
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+    // intersections
+ const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+ const torch::Tensor &flatten_ids, // [n_isects]
+ // forward outputs
+ const torch::Tensor
+ &render_colors, // [C, image_height, image_width, COLOR_DIM]
+ const torch::Tensor &render_alphas, // [C, image_height, image_width, 1]
+ const torch::Tensor &last_ids, // [C, image_height, image_width]
+ const torch::Tensor &median_ids, // [C, image_height, image_width]
+ // gradients of outputs
+ const torch::Tensor &v_render_colors, // [C, image_height, image_width, 3]
+ const torch::Tensor &v_render_alphas, // [C, image_height, image_width, 1]
+ const torch::Tensor &v_render_normals, // [C, image_height, image_width, 3]
+ const torch::Tensor &v_render_distort, // [C, image_height, image_width, 1]
+ const torch::Tensor &v_render_median, // [C, image_height, image_width, 1]
+ // options
+ bool absgrad
+);
+
+std::tuple<torch::Tensor, torch::Tensor>
+rasterize_to_indices_in_range_2dgs_tensor(
+ const uint32_t range_start,
+ const uint32_t range_end, // iteration steps
+ const torch::Tensor transmittances, // [C, image_height, image_width]
+ // Gaussian parameters
+ const torch::Tensor &means2d, // [C, N, 2]
+ const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+ const torch::Tensor &opacities, // [C, N]
+ // image size
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+ // intersections
+ const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+ const torch::Tensor &flatten_ids // [n_isects]
+);
+
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+fully_fused_projection_packed_fwd_2dgs_tensor(
+ const torch::Tensor &means, // [N, 3]
+    const torch::Tensor &quats,    // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const float near_plane,
+ const float far_plane,
+ const float radius_clip
+);
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+fully_fused_projection_packed_bwd_2dgs_tensor(
+ // fwd inputs
+ const torch::Tensor &means, // [N, 3]
+ const torch::Tensor &quats, // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ // fwd outputs
+ const torch::Tensor &camera_ids, // [nnz]
+ const torch::Tensor &gaussian_ids, // [nnz]
+ const torch::Tensor &ray_transforms, // [nnz, 3, 3]
+ // grad outputs
+ const torch::Tensor &v_means2d, // [nnz, 2]
+ const torch::Tensor &v_depths, // [nnz]
+ const torch::Tensor &v_normals, // [nnz, 3]
+ const torch::Tensor &v_ray_transforms, // [nnz, 3, 3]
+ const bool viewmats_requires_grad,
+ const bool sparse_grad
+);
+
} // namespace gsplat
-#endif // GSPLAT_CUDA_BINDINGS_H
\ No newline at end of file
+#endif // GSPLAT_CUDA_BINDINGS_H
diff --git a/gsplat/cuda/csrc/compute_sh_bwd.cu b/gsplat/cuda/csrc/compute_sh_bwd.cu
index 4a4fda3f6..5b22cbdbe 100644
--- a/gsplat/cuda/csrc/compute_sh_bwd.cu
+++ b/gsplat/cuda/csrc/compute_sh_bwd.cu
@@ -93,4 +93,4 @@ std::tuple compute_sh_bwd_tensor(
return std::make_tuple(v_coeffs, v_dirs); // [..., K, 3], [..., 3]
}
-} // namespace gsplat
\ No newline at end of file
+} // namespace gsplat
diff --git a/gsplat/cuda/csrc/ext.cpp b/gsplat/cuda/csrc/ext.cpp
index 57248ac81..0a4a67aac 100644
--- a/gsplat/cuda/csrc/ext.cpp
+++ b/gsplat/cuda/csrc/ext.cpp
@@ -48,4 +48,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
);
m.def("compute_relocation", &gsplat::compute_relocation_tensor);
-}
\ No newline at end of file
+
+ // 2DGS
+ m.def(
+ "fully_fused_projection_fwd_2dgs",
+ &gsplat::fully_fused_projection_fwd_2dgs_tensor
+ );
+ m.def(
+ "fully_fused_projection_bwd_2dgs",
+ &gsplat::fully_fused_projection_bwd_2dgs_tensor
+ );
+
+ m.def(
+ "fully_fused_projection_packed_fwd_2dgs",
+ &gsplat::fully_fused_projection_packed_fwd_2dgs_tensor
+ );
+ m.def(
+ "fully_fused_projection_packed_bwd_2dgs",
+ &gsplat::fully_fused_projection_packed_bwd_2dgs_tensor
+ );
+
+ m.def(
+ "rasterize_to_pixels_fwd_2dgs",
+ &gsplat::rasterize_to_pixels_fwd_2dgs_tensor
+ );
+ m.def(
+ "rasterize_to_pixels_bwd_2dgs",
+ &gsplat::rasterize_to_pixels_bwd_2dgs_tensor
+ );
+
+ m.def(
+ "rasterize_to_indices_in_range_2dgs",
+ &gsplat::rasterize_to_indices_in_range_2dgs_tensor
+ );
+}
diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu
new file mode 100644
index 000000000..ec7eb1126
--- /dev/null
+++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_bwd.cu
@@ -0,0 +1,224 @@
+#include "bindings.h"
+#include "helpers.cuh"
+#include "utils.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/cub.cuh>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Projection of Gaussians (Batched) Backward Pass
+ ****************************************************************************/
+template <typename T>
+__global__ void fully_fused_projection_bwd_2dgs_kernel(
+ // fwd inputs
+ const uint32_t C,
+ const uint32_t N,
+ const T *__restrict__ means, // [N, 3]
+ const T *__restrict__ quats, // [N, 4]
+ const T *__restrict__ scales, // [N, 3]
+ const T *__restrict__ viewmats, // [C, 4, 4]
+ const T *__restrict__ Ks, // [C, 3, 3]
+ const int32_t image_width,
+ const int32_t image_height,
+ // fwd outputs
+ const int32_t *__restrict__ radii, // [C, N]
+ const T *__restrict__ ray_transforms, // [C, N, 3, 3]
+ // grad outputs
+ const T *__restrict__ v_means2d, // [C, N, 2]
+ const T *__restrict__ v_depths, // [C, N]
+ const T *__restrict__ v_normals, // [C, N, 3]
+ // grad inputs
+ T *__restrict__ v_ray_transforms, // [C, N, 3, 3]
+ T *__restrict__ v_means, // [N, 3]
+ T *__restrict__ v_quats, // [N, 4]
+ T *__restrict__ v_scales, // [N, 3]
+ T *__restrict__ v_viewmats // [C, 4, 4]
+) {
+ // parallelize over C * N.
+ uint32_t idx = cg::this_grid().thread_rank();
+ if (idx >= C * N || radii[idx] <= 0) {
+ return;
+ }
+ const uint32_t cid = idx / N; // camera id
+ const uint32_t gid = idx % N; // gaussian id
+
+ // shift pointers to the current camera and gaussian
+ means += gid * 3;
+ viewmats += cid * 16;
+ Ks += cid * 9;
+
+ ray_transforms += idx * 9;
+
+ v_means2d += idx * 2;
+ v_depths += idx;
+ v_normals += idx * 3;
+ v_ray_transforms += idx * 9;
+
+ // transform Gaussian to camera space
+    mat3<T> R = mat3<T>(
+        viewmats[0],
+        viewmats[4],
+        viewmats[8], // 1st column
+        viewmats[1],
+        viewmats[5],
+        viewmats[9], // 2nd column
+        viewmats[2],
+        viewmats[6],
+        viewmats[10] // 3rd column
+    );
+    vec3<T> t = vec3<T>(viewmats[3], viewmats[7], viewmats[11]);
+    vec3<T> mean_c;
+    pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+
+    vec4<T> quat = glm::make_vec4(quats + gid * 4);
+    vec2<T> scale = glm::make_vec2(scales + gid * 3);
+
+    mat3<T> P = mat3<T>(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+
+    mat3<T> _v_ray_transforms = mat3<T>(
+        v_ray_transforms[0],
+        v_ray_transforms[1],
+        v_ray_transforms[2],
+        v_ray_transforms[3],
+        v_ray_transforms[4],
+        v_ray_transforms[5],
+        v_ray_transforms[6],
+        v_ray_transforms[7],
+        v_ray_transforms[8]
+    );
+
+    _v_ray_transforms[2][2] += v_depths[0];
+
+    vec3<T> v_normal = glm::make_vec3(v_normals);
+
+    vec3<T> v_mean(0.f);
+    vec2<T> v_scale(0.f);
+    vec4<T> v_quat(0.f);
+ compute_ray_transforms_aabb_vjp(
+ ray_transforms,
+ v_means2d,
+ v_normal,
+ R,
+ P,
+ t,
+ mean_c,
+ quat,
+ scale,
+ _v_ray_transforms,
+ v_quat,
+ v_scale,
+ v_mean
+ );
+
+ // #if __CUDA_ARCH__ >= 700
+ // write out results with warp-level reduction
+ auto warp = cg::tiled_partition<32>(cg::this_thread_block());
+ auto warp_group_g = cg::labeled_partition(warp, gid);
+ if (v_means != nullptr) {
+ warpSum(v_mean, warp_group_g);
+ if (warp_group_g.thread_rank() == 0) {
+ v_means += gid * 3;
+ GSPLAT_PRAGMA_UNROLL
+ for (uint32_t i = 0; i < 3; i++) {
+ gpuAtomicAdd(v_means + i, v_mean[i]);
+ }
+ }
+ }
+
+ // Directly output gradients w.r.t. the quaternion and scale
+ warpSum(v_quat, warp_group_g);
+ warpSum(v_scale, warp_group_g);
+ if (warp_group_g.thread_rank() == 0) {
+ v_quats += gid * 4;
+ v_scales += gid * 3;
+ gpuAtomicAdd(v_quats, v_quat[0]);
+ gpuAtomicAdd(v_quats + 1, v_quat[1]);
+ gpuAtomicAdd(v_quats + 2, v_quat[2]);
+ gpuAtomicAdd(v_quats + 3, v_quat[3]);
+ gpuAtomicAdd(v_scales, v_scale[0]);
+ gpuAtomicAdd(v_scales + 1, v_scale[1]);
+ }
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+fully_fused_projection_bwd_2dgs_tensor(
+    // fwd inputs
+    const torch::Tensor &means,  // [N, 3]
+    const torch::Tensor &quats,  // [N, 4]
+    const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ // fwd outputs
+ const torch::Tensor &radii, // [C, N]
+ const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+ // grad outputs
+ const torch::Tensor &v_means2d, // [C, N, 2]
+ const torch::Tensor &v_depths, // [C, N]
+ const torch::Tensor &v_normals, // [C, N, 3]
+ const torch::Tensor &v_ray_transforms, // [C, N, 3, 3]
+ const bool viewmats_requires_grad
+) {
+ GSPLAT_DEVICE_GUARD(means);
+ GSPLAT_CHECK_INPUT(means);
+ GSPLAT_CHECK_INPUT(quats);
+ GSPLAT_CHECK_INPUT(scales);
+ GSPLAT_CHECK_INPUT(viewmats);
+ GSPLAT_CHECK_INPUT(Ks);
+ GSPLAT_CHECK_INPUT(radii);
+ GSPLAT_CHECK_INPUT(ray_transforms);
+ GSPLAT_CHECK_INPUT(v_means2d);
+ GSPLAT_CHECK_INPUT(v_depths);
+ GSPLAT_CHECK_INPUT(v_normals);
+ GSPLAT_CHECK_INPUT(v_ray_transforms);
+
+ uint32_t N = means.size(0); // number of gaussians
+ uint32_t C = viewmats.size(0); // number of cameras
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+
+ torch::Tensor v_means = torch::zeros_like(means);
+ torch::Tensor v_quats = torch::zeros_like(quats);
+ torch::Tensor v_scales = torch::zeros_like(scales);
+ torch::Tensor v_viewmats;
+ if (viewmats_requires_grad) {
+ v_viewmats = torch::zeros_like(viewmats);
+ }
+ if (C && N) {
+        fully_fused_projection_bwd_2dgs_kernel<float>
+            <<<(C * N + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS,
+               GSPLAT_N_THREADS,
+               0,
+               stream>>>(
+                C,
+                N,
+                means.data_ptr<float>(),
+                quats.data_ptr<float>(),
+                scales.data_ptr<float>(),
+                viewmats.data_ptr<float>(),
+                Ks.data_ptr<float>(),
+                image_width,
+                image_height,
+                radii.data_ptr<int32_t>(),
+                ray_transforms.data_ptr<float>(),
+                v_means2d.data_ptr<float>(),
+                v_depths.data_ptr<float>(),
+                v_normals.data_ptr<float>(),
+                v_ray_transforms.data_ptr<float>(),
+                v_means.data_ptr<float>(),
+                v_quats.data_ptr<float>(),
+                v_scales.data_ptr<float>(),
+                viewmats_requires_grad ? v_viewmats.data_ptr<float>() : nullptr
+            );
+ }
+ return std::make_tuple(v_means, v_quats, v_scales, v_viewmats);
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
new file mode 100644
index 000000000..08726d0f5
--- /dev/null
+++ b/gsplat/cuda/csrc/fully_fused_projection_2dgs_fwd.cu
@@ -0,0 +1,210 @@
+#include "bindings.h"
+#include "helpers.cuh"
+#include "utils.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/cub.cuh>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Projection of Gaussians (Single Batch) Forward Pass 2DGS
+ ****************************************************************************/
+
+template <typename T>
+__global__ void fully_fused_projection_fwd_2dgs_kernel(
+ const uint32_t C,
+ const uint32_t N,
+ const T *__restrict__ means, // [N, 3]
+ const T *__restrict__ quats, // [N, 4]
+ const T *__restrict__ scales, // [N, 3]
+ const T *__restrict__ viewmats, // [C, 4, 4]
+ const T *__restrict__ Ks, // [C, 3, 3]
+ const int32_t image_width,
+ const int32_t image_height,
+ const T near_plane,
+ const T far_plane,
+ const T radius_clip,
+ // outputs
+ int32_t *__restrict__ radii, // [C, N]
+ T *__restrict__ means2d, // [C, N, 2]
+ T *__restrict__ depths, // [C, N]
+ T *__restrict__ ray_transforms, // [C, N, 3, 3]
+ T *__restrict__ normals // [C, N, 3]
+) {
+ // parallelize over C * N.
+ uint32_t idx = cg::this_grid().thread_rank();
+ if (idx >= C * N) {
+ return;
+ }
+ const uint32_t cid = idx / N; // camera id
+ const uint32_t gid = idx % N; // gaussian id
+
+ // shift pointers to the current camera and gaussian
+ means += gid * 3;
+ viewmats += cid * 16;
+ Ks += cid * 9;
+
+ // glm is column-major but input is row-major
+    mat3<T> R = mat3<T>(
+        viewmats[0],
+        viewmats[4],
+        viewmats[8], // 1st column
+        viewmats[1],
+        viewmats[5],
+        viewmats[9], // 2nd column
+        viewmats[2],
+        viewmats[6],
+        viewmats[10] // 3rd column
+    );
+    vec3<T> t = vec3<T>(viewmats[3], viewmats[7], viewmats[11]);
+
+    // transform Gaussian center to camera space
+    vec3<T> mean_c;
+    pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+    if (mean_c.z < near_plane || mean_c.z > far_plane) {
+        radii[idx] = 0;
+        return;
+    }
+
+    // build ray transformation matrix and transform from world space to camera
+    // space
+    quats += gid * 4;
+    scales += gid * 3;
+
+    mat3<T> RS_camera =
+        R * quat_to_rotmat(glm::make_vec4(quats)) *
+        mat3<T>(scales[0], 0.0, 0.0, 0.0, scales[1], 0.0, 0.0, 0.0, 1.0);
+
+    mat3<T> WH = mat3<T>(RS_camera[0], RS_camera[1], mean_c);
+
+    mat3<T> world_2_pix =
+        mat3<T>(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+    mat3<T> M = glm::transpose(WH) * world_2_pix;
+
+    // compute AABB
+    const vec3<T> M0 = vec3<T>(M[0][0], M[0][1], M[0][2]);
+    const vec3<T> M1 = vec3<T>(M[1][0], M[1][1], M[1][2]);
+    const vec3<T> M2 = vec3<T>(M[2][0], M[2][1], M[2][2]);
+
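+    // The projected center (mean2d) and the 3-sigma half extents below are read
+    // off the conic induced by M under the homogeneous metric diag(1, 1, -1);
+    // `radius` bounds the splat's screen-space AABB for tile culling.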
+    const vec3<T> temp_point = vec3<T>(1.0f, 1.0f, -1.0f);
+    const T distance = sum(temp_point * M2 * M2);
+
+    if (distance == 0.0f)
+        return;
+
+    const vec3<T> f = (1 / distance) * temp_point;
+    const vec2<T> mean2d = vec2<T>(sum(f * M0 * M2), sum(f * M1 * M2));
+
+    const vec2<T> temp = {sum(f * M0 * M0), sum(f * M1 * M1)};
+    const vec2<T> half_extend = mean2d * mean2d - temp;
+    const T radius =
+        ceil(3.f * sqrt(max(1e-4, max(half_extend.x, half_extend.y))));
+
+ if (radius <= radius_clip) {
+ radii[idx] = 0;
+ return;
+ }
+
+ // mask out gaussians outside the image region
+ if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width ||
+ mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) {
+ radii[idx] = 0;
+ return;
+ }
+
+    // flip the normal so it faces the camera ("dual visible" normals)
+    vec3<T> normal = RS_camera[2];
+    T multiplier = glm::dot(-normal, mean_c) > 0 ? 1 : -1;
+    normal *= multiplier;
+
+ // write to outputs
+ radii[idx] = (int32_t)radius;
+ means2d[idx * 2] = mean2d.x;
+ means2d[idx * 2 + 1] = mean2d.y;
+ depths[idx] = mean_c.z;
+ ray_transforms[idx * 9] = M0.x;
+ ray_transforms[idx * 9 + 1] = M0.y;
+ ray_transforms[idx * 9 + 2] = M0.z;
+ ray_transforms[idx * 9 + 3] = M1.x;
+ ray_transforms[idx * 9 + 4] = M1.y;
+ ray_transforms[idx * 9 + 5] = M1.z;
+ ray_transforms[idx * 9 + 6] = M2.x;
+ ray_transforms[idx * 9 + 7] = M2.y;
+ ray_transforms[idx * 9 + 8] = M2.z;
+ normals[idx * 3] = normal.x;
+ normals[idx * 3 + 1] = normal.y;
+ normals[idx * 3 + 2] = normal.z;
+}
+
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+fully_fused_projection_fwd_2dgs_tensor(
+ const torch::Tensor &means, // [N, 3]
+ const torch::Tensor &quats, // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const float eps2d,
+ const float near_plane,
+ const float far_plane,
+ const float radius_clip
+) {
+ GSPLAT_DEVICE_GUARD(means);
+ GSPLAT_CHECK_INPUT(means);
+ GSPLAT_CHECK_INPUT(quats);
+ GSPLAT_CHECK_INPUT(scales);
+ GSPLAT_CHECK_INPUT(viewmats);
+ GSPLAT_CHECK_INPUT(Ks);
+
+ uint32_t N = means.size(0); // number of gaussians
+ uint32_t C = viewmats.size(0); // number of cameras
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+
+ torch::Tensor radii =
+ torch::empty({C, N}, means.options().dtype(torch::kInt32));
+ torch::Tensor means2d = torch::empty({C, N, 2}, means.options());
+ torch::Tensor depths = torch::empty({C, N}, means.options());
+ torch::Tensor ray_transforms = torch::empty({C, N, 3, 3}, means.options());
+ torch::Tensor normals = torch::empty({C, N, 3}, means.options());
+
+ if (C && N) {
+        fully_fused_projection_fwd_2dgs_kernel<float>
+            <<<(C * N + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS,
+               GSPLAT_N_THREADS,
+               0,
+               stream>>>(
+                C,
+                N,
+                means.data_ptr<float>(),
+                quats.data_ptr<float>(),
+                scales.data_ptr<float>(),
+                viewmats.data_ptr<float>(),
+                Ks.data_ptr<float>(),
+                image_width,
+                image_height,
+                near_plane,
+                far_plane,
+                radius_clip,
+                radii.data_ptr<int32_t>(),
+                means2d.data_ptr<float>(),
+                depths.data_ptr<float>(),
+                ray_transforms.data_ptr<float>(),
+                normals.data_ptr<float>()
+            );
+ }
+ return std::make_tuple(radii, means2d, depths, ray_transforms, normals);
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu
new file mode 100644
index 000000000..564eb70bd
--- /dev/null
+++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_bwd.cu
@@ -0,0 +1,270 @@
+#include "bindings.h"
+#include "helpers.cuh"
+#include "utils.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/cub.cuh>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Projection of Gaussians (Batched) Backward Pass 2DGS
+ ****************************************************************************/
+
+template <typename T>
+__global__ void fully_fused_projection_packed_bwd_2dgs_kernel(
+ // fwd inputs
+ const uint32_t C,
+ const uint32_t N,
+ const uint32_t nnz,
+ const T *__restrict__ means, // [N, 3]
+ const T *__restrict__ quats, // [N, 4]
+ const T *__restrict__ scales, // [N, 3]
+ const T *__restrict__ viewmats, // [C, 4, 4]
+ const T *__restrict__ Ks, // [C, 3, 3]
+ const int32_t image_width,
+ const int32_t image_height,
+ // fwd outputs
+ const int64_t *__restrict__ camera_ids, // [nnz]
+ const int64_t *__restrict__ gaussian_ids, // [nnz]
+ const T *__restrict__ ray_transforms, // [nnz, 3]
+ // grad outputs
+ const T *__restrict__ v_means2d, // [nnz, 2]
+ const T *__restrict__ v_depths, // [nnz]
+ const T *__restrict__ v_normals, // [nnz, 3]
+ const bool sparse_grad, // whether the outputs are in COO format [nnz, ...]
+ // grad inputs
+ T *__restrict__ v_ray_transforms,
+ T *__restrict__ v_means, // [N, 3] or [nnz, 3]
+ T *__restrict__ v_quats, // [N, 4] or [nnz, 4] Optional
+ T *__restrict__ v_scales, // [N, 3] or [nnz, 3] Optional
+ T *__restrict__ v_viewmats // [C, 4, 4] Optional
+) {
+ // parallelize over nnz.
+ uint32_t idx = cg::this_grid().thread_rank();
+ if (idx >= nnz) {
+ return;
+ }
+ const int64_t cid = camera_ids[idx]; // camera id
+ const int64_t gid = gaussian_ids[idx]; // gaussian id
+
+ // shift pointers to the current camera and gaussian
+ means += gid * 3;
+ viewmats += cid * 16;
+ Ks += cid * 9;
+
+ ray_transforms += idx * 9;
+
+ v_means2d += idx * 2;
+ v_normals += idx * 3;
+ v_depths += idx;
+ v_ray_transforms += idx * 9;
+
+ // transform Gaussian to camera space
+    mat3<T> R = mat3<T>(
+        viewmats[0],
+        viewmats[4],
+        viewmats[8], // 1st column
+        viewmats[1],
+        viewmats[5],
+        viewmats[9], // 2nd column
+        viewmats[2],
+        viewmats[6],
+        viewmats[10] // 3rd column
+    );
+    vec3<T> t = vec3<T>(viewmats[3], viewmats[7], viewmats[11]);
+    vec3<T> mean_c;
+    pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+
+    vec4<T> quat = glm::make_vec4(quats + gid * 4);
+    vec2<T> scale = glm::make_vec2(scales + gid * 3);
+    mat3<T> P = mat3<T>(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+
+    mat3<T> _v_ray_transforms = mat3<T>(
+        v_ray_transforms[0],
+        v_ray_transforms[1],
+        v_ray_transforms[2],
+        v_ray_transforms[3],
+        v_ray_transforms[4],
+        v_ray_transforms[5],
+        v_ray_transforms[6],
+        v_ray_transforms[7],
+        v_ray_transforms[8]
+    );
+
+    _v_ray_transforms[2][2] += v_depths[0];
+
+    vec3<T> v_normal = glm::make_vec3(v_normals);
+
+    vec3<T> v_mean(0.f);
+    vec2<T> v_scale(0.f);
+    vec4<T> v_quat(0.f);
+ compute_ray_transforms_aabb_vjp(
+ ray_transforms,
+ v_means2d,
+ v_normal,
+ R,
+ P,
+ t,
+ mean_c,
+ quat,
+ scale,
+ _v_ray_transforms,
+ v_quat,
+ v_scale,
+ v_mean
+ );
+
+ auto warp = cg::tiled_partition<32>(cg::this_thread_block());
+ if (sparse_grad) {
+ // write out results with sparse layout
+ if (v_means != nullptr) {
+ v_means += idx * 3;
+ GSPLAT_PRAGMA_UNROLL
+ for (uint32_t i = 0; i < 3; i++) {
+ v_means[i] = v_mean[i];
+ }
+ }
+ v_quats += idx * 4;
+ v_scales += idx * 3;
+ v_quats[0] = v_quat[0];
+ v_quats[1] = v_quat[1];
+ v_quats[2] = v_quat[2];
+ v_quats[3] = v_quat[3];
+ v_scales[0] = v_scale[0];
+ v_scales[1] = v_scale[1];
+ } else {
+ // write out results with dense layout
+ // #if __CUDA_ARCH__ >= 700
+ // write out results with warp-level reduction
+ auto warp_group_g = cg::labeled_partition(warp, gid);
+ if (v_means != nullptr) {
+ warpSum(v_mean, warp_group_g);
+ if (warp_group_g.thread_rank() == 0) {
+ v_means += gid * 3;
+ GSPLAT_PRAGMA_UNROLL
+ for (uint32_t i = 0; i < 3; i++) {
+ gpuAtomicAdd(v_means + i, v_mean[i]);
+ }
+ }
+ }
+ // Directly output gradients w.r.t. the quaternion and scale
+ warpSum(v_quat, warp_group_g);
+ warpSum(v_scale, warp_group_g);
+ if (warp_group_g.thread_rank() == 0) {
+ v_quats += gid * 4;
+ v_scales += gid * 3;
+ gpuAtomicAdd(v_quats, v_quat[0]);
+ gpuAtomicAdd(v_quats + 1, v_quat[1]);
+ gpuAtomicAdd(v_quats + 2, v_quat[2]);
+ gpuAtomicAdd(v_quats + 3, v_quat[3]);
+ gpuAtomicAdd(v_scales, v_scale[0]);
+ gpuAtomicAdd(v_scales + 1, v_scale[1]);
+ }
+ }
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
+fully_fused_projection_packed_bwd_2dgs_tensor(
+ // fwd inputs
+ const torch::Tensor &means, // [N, 3]
+ const torch::Tensor &quats, // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ // fwd outputs
+ const torch::Tensor &camera_ids, // [nnz]
+ const torch::Tensor &gaussian_ids, // [nnz]
+ const torch::Tensor &ray_transforms, // [nnz, 3, 3]
+ // grad outputs
+ const torch::Tensor &v_means2d, // [nnz, 2]
+ const torch::Tensor &v_depths, // [nnz]
+ const torch::Tensor &v_ray_transforms, // [nnz, 3, 3]
+ const torch::Tensor &v_normals, // [nnz, 3]
+ const bool viewmats_requires_grad,
+ const bool sparse_grad
+) {
+
+ GSPLAT_DEVICE_GUARD(means);
+ GSPLAT_CHECK_INPUT(means);
+ GSPLAT_CHECK_INPUT(quats);
+ GSPLAT_CHECK_INPUT(scales);
+ GSPLAT_CHECK_INPUT(viewmats);
+ GSPLAT_CHECK_INPUT(Ks);
+ GSPLAT_CHECK_INPUT(camera_ids);
+ GSPLAT_CHECK_INPUT(gaussian_ids);
+ GSPLAT_CHECK_INPUT(ray_transforms);
+ GSPLAT_CHECK_INPUT(v_means2d);
+ GSPLAT_CHECK_INPUT(v_depths);
+ GSPLAT_CHECK_INPUT(v_normals);
+ GSPLAT_CHECK_INPUT(v_ray_transforms);
+
+ uint32_t N = means.size(0); // number of gaussians
+ uint32_t C = viewmats.size(0); // number of cameras
+ uint32_t nnz = camera_ids.size(0);
+
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+
+ torch::Tensor v_means, v_quats, v_scales, v_viewmats;
+ if (sparse_grad) {
+ v_means = torch::zeros({nnz, 3}, means.options());
+
+ v_quats = torch::zeros({nnz, 4}, quats.options());
+ v_scales = torch::zeros({nnz, 3}, scales.options());
+
+ if (viewmats_requires_grad) {
+ v_viewmats = torch::zeros({C, 4, 4}, viewmats.options());
+ }
+
+ } else {
+ v_means = torch::zeros_like(means);
+
+ v_quats = torch::zeros_like(quats);
+ v_scales = torch::zeros_like(scales);
+
+ if (viewmats_requires_grad) {
+ v_viewmats = torch::zeros_like(viewmats);
+ }
+ }
+ if (nnz) {
+
+        fully_fused_projection_packed_bwd_2dgs_kernel<float>
+            <<<(nnz + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS,
+               GSPLAT_N_THREADS,
+               0,
+               stream>>>(
+                C,
+                N,
+                nnz,
+                means.data_ptr<float>(),
+                quats.data_ptr<float>(),
+                scales.data_ptr<float>(),
+                viewmats.data_ptr<float>(),
+                Ks.data_ptr<float>(),
+                image_width,
+                image_height,
+                camera_ids.data_ptr<int64_t>(),
+                gaussian_ids.data_ptr<int64_t>(),
+                ray_transforms.data_ptr<float>(),
+                v_means2d.data_ptr<float>(),
+                v_depths.data_ptr<float>(),
+                v_normals.data_ptr<float>(),
+                sparse_grad,
+                v_ray_transforms.data_ptr<float>(),
+                v_means.data_ptr<float>(),
+                v_quats.data_ptr<float>(),
+                v_scales.data_ptr<float>(),
+                viewmats_requires_grad ? v_viewmats.data_ptr<float>() : nullptr
+            );
+ }
+ return std::make_tuple(v_means, v_quats, v_scales, v_viewmats);
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu
new file mode 100644
index 000000000..34ac796fb
--- /dev/null
+++ b/gsplat/cuda/csrc/fully_fused_projection_packed_2dgs_fwd.cu
@@ -0,0 +1,328 @@
+
+#include "bindings.h"
+#include "helpers.cuh"
+#include "utils.cuh"
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/cub.cuh>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Projection of Gaussians (Batched) Forward Pass 2DGS
+ ****************************************************************************/
+
+template <typename T>
+__global__ void fully_fused_projection_packed_fwd_2dgs_kernel(
+ const uint32_t C,
+ const uint32_t N,
+ const T *__restrict__ means, // [N, 3]
+ const T *__restrict__ quats, // [N, 4]
+ const T *__restrict__ scales, // [N, 3]
+ const T *__restrict__ viewmats, // [C, 4, 4]
+ const T *__restrict__ Ks, // [C, 3, 3]
+ const int32_t image_width,
+ const int32_t image_height,
+ const T near_plane,
+ const T far_plane,
+ const T radius_clip,
+ const int32_t
+ *__restrict__ block_accum, // [C * blocks_per_row] packing helper
+ int32_t *__restrict__ block_cnts, // [C * blocks_per_row] packing helper
+ // outputs
+ int32_t *__restrict__ indptr, // [C + 1]
+ int64_t *__restrict__ camera_ids, // [nnz]
+ int64_t *__restrict__ gaussian_ids, // [nnz]
+ int32_t *__restrict__ radii, // [nnz]
+ T *__restrict__ means2d, // [nnz, 2]
+ T *__restrict__ depths, // [nnz]
+ T *__restrict__ ray_transforms, // [nnz, 3, 3]
+ T *__restrict__ normals // [nnz, 3]
+) {
+ int32_t blocks_per_row = gridDim.x;
+
+ int32_t row_idx = blockIdx.y; // cid
+ int32_t block_col_idx = blockIdx.x;
+ int32_t block_idx = row_idx * blocks_per_row + block_col_idx;
+
+ int32_t col_idx = block_col_idx * blockDim.x + threadIdx.x; // gid
+
+ bool valid = (row_idx < C) && (col_idx < N);
+
+    // check if points are within the camera's near and far planes
+    vec3<T> mean_c;
+    mat3<T> R;
+ if (valid) {
+ // shift pointers to the current camera and gaussian
+ means += col_idx * 3;
+ viewmats += row_idx * 16;
+
+ // glm is column-major but input is row-major
+        R = mat3<T>(
+            viewmats[0],
+            viewmats[4],
+            viewmats[8], // 1st column
+            viewmats[1],
+            viewmats[5],
+            viewmats[9], // 2nd column
+            viewmats[2],
+            viewmats[6],
+            viewmats[10] // 3rd column
+        );
+        vec3<T> t = vec3<T>(viewmats[3], viewmats[7], viewmats[11]);
+
+ // transform Gaussian center to camera space
+ pos_world_to_cam(R, t, glm::make_vec3(means), mean_c);
+ if (mean_c.z < near_plane || mean_c.z > far_plane) {
+ valid = false;
+ }
+ }
+
+    vec2<T> mean2d;
+    mat3<T> M;
+    T radius;
+    vec3<T> normal;
+ if (valid) {
+ // build ray transformation matrix and transform from world space to
+ // camera space
+ quats += col_idx * 4;
+ scales += col_idx * 3;
+
+        mat3<T> RS_camera =
+            R * quat_to_rotmat(glm::make_vec4(quats)) *
+            mat3<T>(scales[0], 0.0, 0.0, 0.0, scales[1], 0.0, 0.0, 0.0, 1.0);
+        mat3<T> WH = mat3<T>(RS_camera[0], RS_camera[1], mean_c);
+
+        mat3<T> world_2_pix =
+            mat3<T>(Ks[0], 0.0, Ks[2], 0.0, Ks[4], Ks[5], 0.0, 0.0, 1.0);
+        M = glm::transpose(WH) * world_2_pix;
+
+ // compute AABB
+        const vec3<T> M0 = vec3<T>(M[0][0], M[0][1], M[0][2]);
+        const vec3<T> M1 = vec3<T>(M[1][0], M[1][1], M[1][2]);
+        const vec3<T> M2 = vec3<T>(M[2][0], M[2][1], M[2][2]);
+
+        const vec3<T> temp_point = vec3<T>(1.0f, 1.0f, -1.0f);
+        const T distance = sum(temp_point * M2 * M2);
+
+        if (distance == 0.0f)
+            valid = false;
+
+        const vec3<T> f = (1 / distance) * temp_point;
+        mean2d = vec2<T>(sum(f * M0 * M2), sum(f * M1 * M2));
+
+        const vec2<T> temp = {sum(f * M0 * M0), sum(f * M1 * M1)};
+        const vec2<T> half_extend = mean2d * mean2d - temp;
+        radius = ceil(3.f * sqrt(max(1e-4, max(half_extend.x, half_extend.y))));
+
+ if (radius <= radius_clip) {
+ valid = false;
+ }
+
+ // mask out gaussians outside the image region
+ if (mean2d.x + radius <= 0 || mean2d.x - radius >= image_width ||
+ mean2d.y + radius <= 0 || mean2d.y - radius >= image_height) {
+ valid = false;
+ }
+
+        // flip the normal so it faces the camera ("dual visible" normals)
+        normal = RS_camera[2];
+        T multiplier = glm::dot(-normal, mean_c) > 0 ? 1 : -1;
+        normal *= multiplier;
+ }
+
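+    // Two-pass packing: on the first launch only block_cnts is written (a count
+    // of valid splats per block); the host then prefix-sums it into block_accum,
+    // and a second launch scatters each valid splat to its slot via an exclusive
+    // block scan plus the per-block offset.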
+    int32_t thread_data = static_cast<int32_t>(valid);
+ if (block_cnts != nullptr) {
+ // First pass: compute the block-wide sum
+ int32_t aggregate;
+ if (__syncthreads_or(thread_data)) {
+            typedef cub::BlockReduce<int32_t, GSPLAT_N_THREADS> BlockReduce;
+ __shared__ typename BlockReduce::TempStorage temp_storage;
+ aggregate = BlockReduce(temp_storage).Sum(thread_data);
+ } else {
+ aggregate = 0;
+ }
+ if (threadIdx.x == 0) {
+ block_cnts[block_idx] = aggregate;
+ }
+ } else {
+ // Second pass: write out the indices of the non zero elements
+ if (__syncthreads_or(thread_data)) {
+            typedef cub::BlockScan<int32_t, GSPLAT_N_THREADS> BlockScan;
+ __shared__ typename BlockScan::TempStorage temp_storage;
+ BlockScan(temp_storage).ExclusiveSum(thread_data, thread_data);
+ }
+ if (valid) {
+ if (block_idx > 0) {
+ int32_t offset = block_accum[block_idx - 1];
+ thread_data += offset;
+ }
+ // write to outputs
+ camera_ids[thread_data] = row_idx; // cid
+ gaussian_ids[thread_data] = col_idx; // gid
+ radii[thread_data] = (int32_t)radius;
+ means2d[thread_data * 2] = mean2d.x;
+ means2d[thread_data * 2 + 1] = mean2d.y;
+ depths[thread_data] = mean_c.z;
+ ray_transforms[thread_data * 9] = M[0][0];
+ ray_transforms[thread_data * 9 + 1] = M[0][1];
+ ray_transforms[thread_data * 9 + 2] = M[0][2];
+ ray_transforms[thread_data * 9 + 3] = M[1][0];
+ ray_transforms[thread_data * 9 + 4] = M[1][1];
+ ray_transforms[thread_data * 9 + 5] = M[1][2];
+ ray_transforms[thread_data * 9 + 6] = M[2][0];
+ ray_transforms[thread_data * 9 + 7] = M[2][1];
+ ray_transforms[thread_data * 9 + 8] = M[2][2];
+ normals[thread_data * 3] = normal.x;
+ normals[thread_data * 3 + 1] = normal.y;
+ normals[thread_data * 3 + 2] = normal.z;
+ }
+ // lane 0 of the first block in each row writes the indptr
+ if (threadIdx.x == 0 && block_col_idx == 0) {
+ if (row_idx == 0) {
+ indptr[0] = 0;
+ indptr[C] = block_accum[C * blocks_per_row - 1];
+ } else {
+ indptr[row_idx] = block_accum[block_idx - 1];
+ }
+ }
+ }
+}
+
+std::tuple<
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor,
+ torch::Tensor>
+fully_fused_projection_packed_fwd_2dgs_tensor(
+ const torch::Tensor &means, // [N, 3]
+    const torch::Tensor &quats,    // [N, 4]
+ const torch::Tensor &scales, // [N, 3]
+ const torch::Tensor &viewmats, // [C, 4, 4]
+ const torch::Tensor &Ks, // [C, 3, 3]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const float near_plane,
+ const float far_plane,
+ const float radius_clip
+) {
+ GSPLAT_DEVICE_GUARD(means);
+ GSPLAT_CHECK_INPUT(means);
+ GSPLAT_CHECK_INPUT(quats);
+ GSPLAT_CHECK_INPUT(scales);
+ GSPLAT_CHECK_INPUT(viewmats);
+ GSPLAT_CHECK_INPUT(Ks);
+
+ uint32_t N = means.size(0); // number of gaussians
+ uint32_t C = viewmats.size(0); // number of cameras
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+ auto opt = means.options().dtype(torch::kInt32);
+
+ uint32_t nrows = C;
+ uint32_t ncols = N;
+ uint32_t blocks_per_row = (ncols + GSPLAT_N_THREADS - 1) / GSPLAT_N_THREADS;
+
+ dim3 threads = {GSPLAT_N_THREADS, 1, 1};
+ // limit on the number of blocks: [2**31 - 1, 65535, 65535]
+ dim3 blocks = {blocks_per_row, nrows, 1};
+
+ // first pass
+ int32_t nnz;
+ torch::Tensor block_accum;
+ if (C && N) {
+ torch::Tensor block_cnts = torch::empty({nrows * blocks_per_row}, opt);
+        fully_fused_projection_packed_fwd_2dgs_kernel<float>
+            <<<blocks, threads, 0, stream>>>(
+                C,
+                N,
+                means.data_ptr<float>(),
+                quats.data_ptr<float>(),
+                scales.data_ptr<float>(),
+                viewmats.data_ptr<float>(),
+                Ks.data_ptr<float>(),
+                image_width,
+                image_height,
+                near_plane,
+                far_plane,
+                radius_clip,
+                nullptr,
+                block_cnts.data_ptr<int32_t>(),
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr,
+                nullptr
+            );
+ block_accum = torch::cumsum(block_cnts, 0, torch::kInt32);
+        nnz = block_accum[-1].item<int32_t>();
+ } else {
+ nnz = 0;
+ }
+
+ // second pass
+ torch::Tensor indptr = torch::empty({C + 1}, opt);
+ torch::Tensor camera_ids = torch::empty({nnz}, opt.dtype(torch::kInt64));
+ torch::Tensor gaussian_ids = torch::empty({nnz}, opt.dtype(torch::kInt64));
+ torch::Tensor radii =
+ torch::empty({nnz}, means.options().dtype(torch::kInt32));
+ torch::Tensor means2d = torch::empty({nnz, 2}, means.options());
+ torch::Tensor depths = torch::empty({nnz}, means.options());
+ torch::Tensor ray_transforms = torch::empty({nnz, 3, 3}, means.options());
+ torch::Tensor normals = torch::empty({nnz, 3}, means.options());
+
+ if (nnz) {
+        fully_fused_projection_packed_fwd_2dgs_kernel<float>
+            <<<blocks, threads, 0, stream>>>(
+                C,
+                N,
+                means.data_ptr<float>(),
+                quats.data_ptr<float>(),
+                scales.data_ptr<float>(),
+                viewmats.data_ptr<float>(),
+                Ks.data_ptr<float>(),
+                image_width,
+                image_height,
+                near_plane,
+                far_plane,
+                radius_clip,
+                block_accum.data_ptr<int32_t>(),
+                nullptr,
+                indptr.data_ptr<int32_t>(),
+                camera_ids.data_ptr<int64_t>(),
+                gaussian_ids.data_ptr<int64_t>(),
+                radii.data_ptr<int32_t>(),
+                means2d.data_ptr<float>(),
+                depths.data_ptr<float>(),
+                ray_transforms.data_ptr<float>(),
+                normals.data_ptr<float>()
+            );
+ } else {
+ indptr.fill_(0);
+ }
+
+ return std::make_tuple(
+ indptr,
+ camera_ids,
+ gaussian_ids,
+ radii,
+ means2d,
+ depths,
+ ray_transforms,
+ normals
+ );
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu
index 27a337a34..12f21c611 100644
--- a/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu
+++ b/gsplat/cuda/csrc/fully_fused_projection_packed_fwd.cu
@@ -1,4 +1,3 @@
-
#include "bindings.h"
#include "helpers.cuh"
#include "utils.cuh"
diff --git a/gsplat/cuda/csrc/helpers.cuh b/gsplat/cuda/csrc/helpers.cuh
index 0e4995bad..93ff6bd45 100644
--- a/gsplat/cuda/csrc/helpers.cuh
+++ b/gsplat/cuda/csrc/helpers.cuh
@@ -73,6 +73,10 @@ inline __device__ void warpMax(ScalarT &val, WarpT &warp) {
val = cg::reduce(warp, val, cg::greater<ScalarT>());
}
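+// Sum of the three components of a vec3.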
+template <typename T> __forceinline__ __device__ T sum(vec3<T> a) {
+ return a.x + a.y + a.z;
+}
+
} // namespace gsplat
-#endif // GSPLAT_CUDA_HELPERS_H
\ No newline at end of file
+#endif // GSPLAT_CUDA_HELPERS_H
diff --git a/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu
new file mode 100644
index 000000000..1432a1666
--- /dev/null
+++ b/gsplat/cuda/csrc/rasterize_to_indices_in_range_2dgs.cu
@@ -0,0 +1,339 @@
+#include "bindings.h"
+#include "helpers.cuh"
+#include "types.cuh"
+#include "utils.cuh"
+#include <cooperative_groups.h>
+#include <cub/cub.cuh>
+#include <cuda_runtime.h>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Rasterization to Indices in Range 2DGS
+ ****************************************************************************/
+
+template <typename T>
+__global__ void rasterize_to_indices_in_range_kernel(
+ const uint32_t range_start,
+ const uint32_t range_end,
+ const uint32_t C,
+ const uint32_t N,
+ const uint32_t n_isects,
+ const vec2<T> *__restrict__ means2d, // [C, N, 2]
+ const T *__restrict__ ray_transforms, // [C, N, 3, 3]
+ const T *__restrict__ opacities, // [C, N]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+ const uint32_t tile_width,
+ const uint32_t tile_height,
+ const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
+ const int32_t *__restrict__ flatten_ids, // [n_isects]
+ const T *__restrict__ transmittances, // [C, image_height, image_width]
+ const int32_t *__restrict__ chunk_starts, // [C, image_height, image_width]
+ int32_t *__restrict__ chunk_cnts, // [C, image_height, image_width]
+ int64_t *__restrict__ gaussian_ids, // [n_elems]
+ int64_t *__restrict__ pixel_ids // [n_elems]
+) {
+ // each thread draws one pixel, but also timeshares caching gaussians in a
+ // shared tile
+
+ auto block = cg::this_thread_block();
+ uint32_t camera_id = block.group_index().x;
+ uint32_t tile_id =
+ block.group_index().y * tile_width + block.group_index().z;
+ uint32_t i = block.group_index().y * tile_size + block.thread_index().y;
+ uint32_t j = block.group_index().z * tile_size + block.thread_index().x;
+
+ // move pointers to the current camera
+ tile_offsets += camera_id * tile_height * tile_width;
+ transmittances += camera_id * image_height * image_width;
+
+ T px = (T)j + 0.5f;
+ T py = (T)i + 0.5f;
+ int32_t pix_id = i * image_width + j;
+
+ // return if out of bounds
+ // keep not rasterizing threads around for reading data
+ bool inside = (i < image_height && j < image_width);
+ bool done = !inside;
+
+ bool first_pass = chunk_starts == nullptr;
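+ // chunk_starts doubles as a pass selector: when it is null we only count
+ // contributions per pixel (first pass); otherwise it holds each pixel's
+ // write offset into gaussian_ids / pixel_ids (second pass)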
+ int32_t base;
+ if (!first_pass && inside) {
+ chunk_starts += camera_id * image_height * image_width;
+ base = chunk_starts[pix_id];
+ }
+
+ // all threads in the tile process the same gaussians in batches:
+ // collect gaussians between range_start and range_end batch by batch
+ // and determine which gaussians to look through in this tile
+ int32_t isect_range_start = tile_offsets[tile_id];
+ int32_t isect_range_end =
+ (camera_id == C - 1) && (tile_id == tile_width * tile_height - 1)
+ ? n_isects
+ : tile_offsets[tile_id + 1];
+ const uint32_t block_size = block.size();
+ uint32_t num_batches =
+ (isect_range_end - isect_range_start + block_size - 1) / block_size;
+
+ if (range_start >= num_batches) {
+ // this entire tile has been processed in the previous iterations
+ // so we don't need to do anything.
+ return;
+ }
+
+ extern __shared__ int s[];
+ int32_t *id_batch = (int32_t *)s; // [block_size]
+ vec3<T> *xy_opacity_batch =
+ reinterpret_cast<vec3<T> *>(&id_batch[block_size]); // [block_size]
+ vec3<T> *u_Ms_batch =
+ reinterpret_cast<vec3<T> *>(&xy_opacity_batch[block_size]); // [block_size]
+ vec3<T> *v_Ms_batch =
+ reinterpret_cast<vec3<T> *>(&u_Ms_batch[block_size]); // [block_size]
+ vec3<T> *w_Ms_batch =
+ reinterpret_cast<vec3<T> *>(&v_Ms_batch[block_size]); // [block_size]
+
+ // current visibility left to render
+ // transmittance is used in the backward pass, which requires high
+ // numerical precision, so ideally we would use double for it. However,
+ // double makes the backward pass ~1.5x slower, so we stick with float.
+ T trans, next_trans;
+ if (inside) {
+ trans = transmittances[pix_id];
+ next_trans = trans;
+ }
+
+ // collect and process batches of gaussians
+ // each thread loads one gaussian at a time before rasterizing its
+ // designated pixel
+ uint32_t tr = block.thread_rank();
+
+ int32_t cnt = 0;
+ for (uint32_t b = range_start; b < min(range_end, num_batches); ++b) {
+ // resync all threads before beginning next batch
+ // end early if entire tile is done
+ if (__syncthreads_count(done) >= block_size) {
+ break;
+ }
+
+ // each thread fetches one gaussian, front to back
+ // index of the gaussian to load
+ uint32_t batch_start = isect_range_start + block_size * b;
+ uint32_t idx = batch_start + tr;
+ if (idx < isect_range_end) {
+ int32_t g = flatten_ids[idx];
+ id_batch[tr] = g;
+ const vec2<T> xy = means2d[g];
+ const T opac = opacities[g];
+ xy_opacity_batch[tr] = {xy.x, xy.y, opac};
+ u_Ms_batch[tr] = {
+ ray_transforms[g * 9], ray_transforms[g * 9 + 1], ray_transforms[g * 9 + 2]
+ };
+ v_Ms_batch[tr] = {
+ ray_transforms[g * 9 + 3], ray_transforms[g * 9 + 4], ray_transforms[g * 9 + 5]
+ };
+ w_Ms_batch[tr] = {
+ ray_transforms[g * 9 + 6], ray_transforms[g * 9 + 7], ray_transforms[g * 9 + 8]
+ };
+ }
+
+ // wait for other threads to collect the gaussians in batch
+ block.sync();
+
+ // process gaussians in the current batch for this pixel
+ uint32_t batch_size = min(block_size, isect_range_end - batch_start);
+ for (uint32_t t = 0; (t < batch_size) && !done; ++t) {
+ const vec3<T> u_M = u_Ms_batch[t];
+ const vec3<T> v_M = v_Ms_batch[t];
+ const vec3<T> w_M = w_Ms_batch[t];
+ const vec3<T> xy_opac = xy_opacity_batch[t];
+ const T opac = xy_opac.z;
+
+ const vec3<T> h_u = px * w_M - u_M;
+ const vec3<T> h_v = py * w_M - v_M;
+
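+ // h_u and h_v are the pixel's x- and y-planes mapped into the splat's
+ // (u, v) tangent frame; their cross product is the homogeneous point
+ // where the pixel ray meets the tangent plane, and dividing by its z
+ // component recovers the (u, v) intersection coordinates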
+ const vec3<T> ray_cross = glm::cross(h_u, h_v);
+
+ if (ray_cross.z == 0.0)
+ continue;
+
+ const vec2<T> s = {
+ ray_cross.x / ray_cross.z, ray_cross.y / ray_cross.z
+ };
+ const T gauss_weight_3d = s.x * s.x + s.y * s.y;
+
+ // Low pass filter
+ const vec2<T> d = {xy_opac.x - px, xy_opac.y - py};
+ // 2D screen distance
+ const T gauss_weight_2d =
+ FILTER_INV_SQUARE * (d.x * d.x + d.y * d.y);
+ const T gauss_weight = min(gauss_weight_3d, gauss_weight_2d);
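+ // taking the min of the two squared distances keeps the larger of the
+ // two gaussian responses, so nearly edge-on splats still cover at
+ // least their screen-space (low-pass filtered) footprint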
+
+ const T sigma = 0.5f * gauss_weight;
+ T alpha = min(0.999f, opac * __expf(-sigma));
+
+ if (sigma < 0.f || alpha < 1.f / 255.f) {
+ continue;
+ }
+
+ next_trans = trans * (1.0f - alpha);
+ if (next_trans <= 1e-4) { // this pixel is done: exclusive
+ done = true;
+ break;
+ }
+
+ if (first_pass) {
+ // in the first pass we only count the number of gaussians
+ // that contribute to each pixel
+ cnt += 1;
+ } else {
+ // in the second pass we write out the gaussian ids and pixel ids
+ int32_t g = id_batch[t]; // flatten index in [C * N]
+ gaussian_ids[base + cnt] = g % N;
+ pixel_ids[base + cnt] =
+ pix_id + camera_id * image_height * image_width;
+ cnt += 1;
+ }
+
+ trans = next_trans;
+ }
+ }
+
+ if (inside && first_pass) {
+ chunk_cnts += camera_id * image_height * image_width;
+ chunk_cnts[pix_id] = cnt;
+ }
+}
+
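+// Host entry: two launches of the kernel above. The first counts how many
+// gaussians contribute to each pixel (chunk_cnts); a cumsum turns the counts
+// into per-pixel offsets (chunk_starts); the second launch writes the
+// flattened gaussian/pixel id pairs into the preallocated outputs.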
+std::tuple<torch::Tensor, torch::Tensor>
+rasterize_to_indices_in_range_2dgs_tensor(
+ const uint32_t range_start,
+ const uint32_t range_end, // iteration steps
+ const torch::Tensor transmittances, // [C, image_height, image_width]
+ // Gaussian parameters
+ const torch::Tensor &means2d, // [C, N, 2]
+ const torch::Tensor &ray_transforms, // [C, N, 3, 3]
+ const torch::Tensor &opacities, // [C, N]
+ // image size
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+ // intersections
+ const torch::Tensor &tile_offsets, // [C, tile_height, tile_width]
+ const torch::Tensor &flatten_ids // [n_isects]
+) {
+ GSPLAT_DEVICE_GUARD(means2d);
+ GSPLAT_CHECK_INPUT(means2d);
+ GSPLAT_CHECK_INPUT(ray_transforms);
+ GSPLAT_CHECK_INPUT(opacities);
+ GSPLAT_CHECK_INPUT(tile_offsets);
+ GSPLAT_CHECK_INPUT(flatten_ids);
+
+ uint32_t C = means2d.size(0); // number of cameras
+ uint32_t N = means2d.size(1); // number of gaussians
+ uint32_t tile_height = tile_offsets.size(1);
+ uint32_t tile_width = tile_offsets.size(2);
+ uint32_t n_isects = flatten_ids.size(0);
+
+ // Each block covers a tile on the image. In total there are
+ // C * tile_height * tile_width blocks.
+ dim3 threads = {tile_size, tile_size, 1};
+ dim3 blocks = {C, tile_height, tile_width};
+
+ at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+ const uint32_t shared_mem =
+ tile_size * tile_size *
+ (sizeof(int32_t) + sizeof(vec3<float>) + sizeof(vec3<float>) +
+ sizeof(vec3<float>) + sizeof(vec3<float>));
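+ // per gaussian slot: one int32 id plus four vec3<float>s (xy + opacity
+ // and the three rows of the ray transform), matching the kernel's
+ // shared-memory batches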
+ if (cudaFuncSetAttribute(
+ rasterize_to_indices_in_range_kernel<float>,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ shared_mem
+ ) != cudaSuccess) {
+ AT_ERROR(
+ "Failed to set maximum shared memory size (requested ",
+ shared_mem,
+ " bytes), try lowering tile_size."
+ );
+ }
+
+ // First pass: count the number of gaussians that contribute to each pixel
+ int64_t n_elems;
+ torch::Tensor chunk_starts;
+ if (n_isects) {
+ torch::Tensor chunk_cnts = torch::zeros(
+ {C * image_height * image_width},
+ means2d.options().dtype(torch::kInt32)
+ );
+ rasterize_to_indices_in_range_kernel<float>
+ <<<blocks, threads, shared_mem, stream>>>(
+ range_start,
+ range_end,
+ C,
+ N,
+ n_isects,
+ reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
+ ray_transforms.data_ptr<float>(),
+ opacities.data_ptr<float>(),
+ image_width,
+ image_height,
+ tile_size,
+ tile_width,
+ tile_height,
+ tile_offsets.data_ptr<int32_t>(),
+ flatten_ids.data_ptr<int32_t>(),
+ transmittances.data_ptr<float>(),
+ nullptr,
+ chunk_cnts.data_ptr<int32_t>(),
+ nullptr,
+ nullptr
+ );
+
+ torch::Tensor cumsum =
+ torch::cumsum(chunk_cnts, 0, chunk_cnts.scalar_type());
+ n_elems = cumsum[-1].item<int64_t>();
+ chunk_starts = cumsum - chunk_cnts;
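+ // subtracting the counts from the inclusive cumsum gives exclusive
+ // offsets, i.e. where each pixel's chunk of ids begins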
+ } else {
+ n_elems = 0;
+ }
+
+ // Second pass: allocate memory and write out the gaussian and pixel ids.
+ torch::Tensor gaussian_ids =
+ torch::empty({n_elems}, means2d.options().dtype(torch::kInt64));
+ torch::Tensor pixel_ids =
+ torch::empty({n_elems}, means2d.options().dtype(torch::kInt64));
+ if (n_elems) {
+ rasterize_to_indices_in_range_kernel<float>
+ <<<blocks, threads, shared_mem, stream>>>(
+ range_start,
+ range_end,
+ C,
+ N,
+ n_isects,
+ reinterpret_cast<vec2<float> *>(means2d.data_ptr<float>()),
+ ray_transforms.data_ptr<float>(),
+ opacities.data_ptr<float>(),
+ image_width,
+ image_height,
+ tile_size,
+ tile_width,
+ tile_height,
+ tile_offsets.data_ptr<int32_t>(),
+ flatten_ids.data_ptr<int32_t>(),
+ transmittances.data_ptr<float>(),
+ chunk_starts.data_ptr<int32_t>(),
+ nullptr,
+ gaussian_ids.data_ptr<int64_t>(),
+ pixel_ids.data_ptr<int64_t>()
+ );
+ }
+ return std::make_tuple(gaussian_ids, pixel_ids);
+}
+
+} // namespace gsplat
\ No newline at end of file
diff --git a/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu
new file mode 100644
index 000000000..ddf508283
--- /dev/null
+++ b/gsplat/cuda/csrc/rasterize_to_pixels_2dgs_bwd.cu
@@ -0,0 +1,711 @@
+#include "bindings.h"
+#include "helpers.cuh"
+#include "types.cuh"
+#include "utils.cuh"
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+#include <cub/cub.cuh>
+
+namespace gsplat {
+
+namespace cg = cooperative_groups;
+
+/****************************************************************************
+ * Rasterization to Pixels Backward Pass 2DGS
+ ****************************************************************************/
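+// Backward pass: each tile replays the forward compositing order and
+// accumulates gradients from the rendered colors, alphas, normals,
+// distortion, and median depth back to the per-splat inputs.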
+template <uint32_t COLOR_DIM, typename S>
+__global__ void rasterize_to_pixels_bwd_2dgs_kernel(
+ const uint32_t C,
+ const uint32_t N,
+ const uint32_t n_isects,
+ const bool packed,
+ // fwd inputs
+ const vec2<S> *__restrict__ means2d, // [C, N, 2] or [nnz, 2]
+ const S *__restrict__ ray_transforms, // [C, N, 3, 3] or [nnz, 3, 3]
+ const S *__restrict__ colors, // [C, N, COLOR_DIM] or [nnz, COLOR_DIM]
+ const S *__restrict__ normals, // [C, N, 3] or [nnz, 3]
+ const S *__restrict__ opacities, // [C, N] or [nnz]
+ const S *__restrict__ backgrounds, // [C, COLOR_DIM] or [nnz, COLOR_DIM]
+ const bool *__restrict__ masks, // [C, tile_height, tile_width]
+ const uint32_t image_width,
+ const uint32_t image_height,
+ const uint32_t tile_size,
+ const uint32_t tile_width,
+ const uint32_t tile_height,
+ const int32_t *__restrict__ tile_offsets, // [C, tile_height, tile_width]
+ const int32_t *__restrict__ flatten_ids, // [n_isects]
+ // fwd outputs
+ const S *__restrict__ render_colors, // [C, image_height, image_width,
+ // COLOR_DIM]
+ const S *__restrict__ render_alphas, // [C, image_height, image_width, 1]
+ const int32_t *__restrict__ last_ids, // [C, image_height, image_width]
+ const int32_t *__restrict__ median_ids, // [C, image_height, image_width]
+ // grad outputs
+ const S *__restrict__ v_render_colors, // [C, image_height, image_width,
+ // COLOR_DIM]
+ const S *__restrict__ v_render_alphas, // [C, image_height, image_width, 1]
+ const S *__restrict__ v_render_normals, // [C, image_height, image_width, 3]
+ const S *__restrict__ v_render_distort, // [C, image_height, image_width, 1]
+ const S *__restrict__ v_render_median, // [C, image_height, image_width, 1]
+ // grad inputs
+ vec2