From 7e7188779ca574ea599f87fde0deeb766a6754e6 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 18 Sep 2024 16:16:33 +0200 Subject: [PATCH 01/29] initial impl without densification --- examples/simple_trainer_scaffold.py | 1069 +++++++++++++++++++++++++++ gsplat/rendering.py | 214 +++++- gsplat/strategy/__init__.py | 1 + gsplat/strategy/scaffold.py | 309 ++++++++ 4 files changed, 1592 insertions(+), 1 deletion(-) create mode 100644 examples/simple_trainer_scaffold.py create mode 100644 gsplat/strategy/scaffold.py diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py new file mode 100644 index 000000000..300842750 --- /dev/null +++ b/examples/simple_trainer_scaffold.py @@ -0,0 +1,1069 @@ +import json +import math +import os +import time +from dataclasses import dataclass, field +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import imageio +import nerfview +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +import tyro +import viser +import yaml +from datasets.colmap import Dataset, Parser +from datasets.traj import ( + generate_interpolated_path, + generate_ellipse_path_z, + generate_spiral_path, +) +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure +from fused_ssim import fused_ssim +from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity +from typing_extensions import Literal +from utils import AppearanceOptModule, CameraOptModule, knn, set_random_seed +from lib_bilagrid import ( + BilateralGrid, + slice, + color_correct, + total_variation_loss, +) + +from gsplat.compression import PngCompression +from gsplat.distributed import cli +from gsplat.rendering import rasterization, filter_visible_gaussians +from gsplat.strategy import ScaffoldStrategy + + +@dataclass +class Config: + # Disable viewer + disable_viewer: bool = False + # Path to the .pt files. If provide, it will skip training and run evaluation only. + ckpt: Optional[List[str]] = None + # Name of compression strategy to use + compression: Optional[Literal["png"]] = None + # Render trajectory path + render_traj_path: str = "interp" + + # Path to the Mip-NeRF 360 dataset + data_dir: str = "/home/paja/data/bike" + # Downsample factor for the dataset + data_factor: int = 1 + # Directory to save results + result_dir: str = "results" + # Every N images there is a test image + test_every: int = 8 + # Random crop size for training (experimental) + patch_size: Optional[int] = None + # A global scaler that applies to the scene size related parameters + global_scale: float = 1.0 + # Normalize the world space + normalize_world_space: bool = True + + # Port for the viewer server + port: int = 8080 + + # Batch size for training. Learning rates are scaled automatically + batch_size: int = 1 + # A global factor to scale the number of training steps + steps_scaler: float = 1.0 + + # Number of training steps + max_steps: int = 30_000 + # Steps to evaluate the model + eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + # Steps to save the model + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + + # voxel size for Scaffold-GS + voxel_size = 0.001 + # Initial extent of GSs as a multiple of the camera extent. 
Ignored if using sfm.
+    init_extent: float = 3.0
+    # Turn on another SH degree every this many steps
+    sh_degree_interval: int = 1000
+    # Initial opacity of GS
+    init_opa: float = 0.1
+    # Initial scale of GS
+    init_scale: float = 1.0
+    # Weight for SSIM loss
+    ssim_lambda: float = 0.2
+
+    # Near plane clipping distance
+    near_plane: float = 0.01
+    # Far plane clipping distance
+    far_plane: float = 1e10
+
+    # Strategy for GS densification
+    strategy: ScaffoldStrategy = field(default_factory=ScaffoldStrategy)
+    # Use packed mode for rasterization; this uses less memory but is slightly slower.
+    packed: bool = False
+    # Use sparse gradients for optimization. (experimental)
+    sparse_grad: bool = False
+    # Anti-aliasing in rasterization. Might slightly hurt quantitative metrics.
+    antialiased: bool = False
+
+    # Use random background for training to discourage transparency
+    random_bkgd: bool = False
+
+    # Opacity regularization
+    opacity_reg: float = 0.0
+    # Scale regularization
+    scale_reg: float = 0.01
+
+    # Enable camera optimization.
+    pose_opt: bool = False
+    # Learning rate for camera optimization
+    pose_opt_lr: float = 1e-5
+    # Regularization for camera optimization as weight decay
+    pose_opt_reg: float = 1e-6
+    # Add noise to camera extrinsics. This is only to test the camera pose optimization.
+    pose_noise: float = 0.0
+
+    # Enable appearance optimization. (experimental)
+    app_opt: bool = False
+    # Appearance embedding dimension
+    app_embed_dim: int = 16
+    # Learning rate for appearance optimization
+    app_opt_lr: float = 1e-3
+    # Regularization for appearance optimization as weight decay
+    app_opt_reg: float = 1e-6
+
+    # Enable bilateral grid. (experimental)
+    use_bilateral_grid: bool = False
+    # Shape of the bilateral grid (X, Y, W)
+    bilateral_grid_shape: Tuple[int, int, int] = (16, 16, 8)
+
+    # Enable depth loss. (experimental)
+    depth_loss: bool = False
+    # Weight for depth loss
+    depth_lambda: float = 1e-2
+
+    # Dump information to tensorboard every this many steps
+    tb_every: int = 100
+    # Save training images to tensorboard
+    tb_save_image: bool = False
+
+    lpips_net: Literal["vgg", "alex"] = "alex"
+
+    def adjust_steps(self, factor: float):
+        self.eval_steps = [int(i * factor) for i in self.eval_steps]
+        self.save_steps = [int(i * factor) for i in self.save_steps]
+        self.max_steps = int(self.max_steps * factor)
+        self.sh_degree_interval = int(self.sh_degree_interval * factor)
+
+        strategy = self.strategy
+        strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
+        strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor)
+        # strategy.reset_every = int(strategy.reset_every * factor)
+        # strategy.refine_every = int(strategy.refine_every * factor)
+
+
+def create_splats_with_optimizers(
+    parser: Parser,
+    strategy: ScaffoldStrategy,
+    init_extent: float = 3.0,
+    init_opacity: float = 0.1,
+    init_scale: float = 1.0,
+    scene_scale: float = 1.0,
+    sh_degree: int = 3,
+    sparse_grad: bool = False,
+    batch_size: int = 1,
+    feature_dim: Optional[int] = None,
+    device: str = "cuda",
+    world_rank: int = 0,
+    world_size: int = 1,
+    voxel_size: float = 0.001,
+) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]:
+
+    # Snap the SfM points to a voxel grid; compare Scaffold-GS paper, formula (4)
+    points = np.unique(np.round(parser.points / voxel_size), axis=0) * voxel_size
+    points = torch.from_numpy(points).float()
+
+    # Initialize the GS size to be the average dist of the 3 nearest neighbors
+    dist2_avg = (knn(points, 4)[:, 1:] ** 2).mean(dim=-1)  # [N,]
+    dist_avg = torch.sqrt(dist2_avg)
+    scales = torch.log(dist_avg * init_scale).unsqueeze(-1).repeat(1, 6)  # [N, 6]
+
+    # Distribute the GSs to different ranks (also works for single rank)
+    points = points[world_rank::world_size]
+    scales = scales[world_rank::world_size]
+
+    N = points.shape[0]
+    quats = torch.rand((N, 4))  # [N, 4]
+
+    features = torch.zeros((N, strategy.mean_feat_dim))
+    offsets = torch.zeros((N, strategy.n_feat_offsets, 3))
+
+    params = [
+        # name, value, lr
+        ("anchors", torch.nn.Parameter(points), 1.6e-4 * scene_scale),
+        ("scales", torch.nn.Parameter(scales), 5e-3),
+        ("quats", torch.nn.Parameter(quats), 1e-3),
+        ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.02),
+        ("features", torch.nn.Parameter(features), 1.6e-4 * scene_scale),
+        ("offsets", torch.nn.Parameter(offsets), 0.004),
+        ("colors_mlp", strategy.colors_mlp.parameters(), 0.008),
+        ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004),
+    ]
+
+    splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device)
+    # Scale learning rate based on batch size, reference:
+    # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/
+    # Note that this would not make the training exactly equivalent, see
+    # https://arxiv.org/pdf/2402.18824v1
+    BS = batch_size * world_size
+    optimizers = {
+        name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)(
+            [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}],
+            eps=1e-15 / math.sqrt(BS),
+            # TODO: check the betas logic; once BS reaches 10, betas[0] hits zero
+            # (and goes negative beyond that).
+ betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, _, lr in params + } + return splats, optimizers + + +class Runner: + """Engine for training and testing.""" + + def __init__( + self, local_rank: int, world_rank, world_size: int, cfg: Config + ) -> None: + set_random_seed(42 + local_rank) + + self.cfg = cfg + self.world_rank = world_rank + self.local_rank = local_rank + self.world_size = world_size + self.device = f"cuda:{local_rank}" + + # Where to dump results. + os.makedirs(cfg.result_dir, exist_ok=True) + + # Setup output directories. + self.ckpt_dir = f"{cfg.result_dir}/ckpts" + os.makedirs(self.ckpt_dir, exist_ok=True) + self.stats_dir = f"{cfg.result_dir}/stats" + os.makedirs(self.stats_dir, exist_ok=True) + self.render_dir = f"{cfg.result_dir}/renders" + os.makedirs(self.render_dir, exist_ok=True) + + # Tensorboard + self.writer = SummaryWriter(log_dir=f"{cfg.result_dir}/tb") + + # Load data: Training data should contain initial points and colors. + self.parser = Parser( + data_dir=cfg.data_dir, + factor=cfg.data_factor, + normalize=cfg.normalize_world_space, + test_every=cfg.test_every, + ) + self.trainset = Dataset( + self.parser, + split="train", + patch_size=cfg.patch_size, + load_depths=cfg.depth_loss, + ) + self.valset = Dataset(self.parser, split="val") + self.scene_scale = self.parser.scene_scale * 1.1 * cfg.global_scale + print("Scene scale:", self.scene_scale) + + # Model + feature_dim = 32 if cfg.app_opt else None + self.splats, self.optimizers = create_splats_with_optimizers( + self.parser, + strategy=self.cfg.strategy, + init_extent=cfg.init_extent, + init_opacity=cfg.init_opa, + init_scale=cfg.init_scale, + scene_scale=self.scene_scale, + sparse_grad=cfg.sparse_grad, + batch_size=cfg.batch_size, + feature_dim=feature_dim, + device=self.device, + world_rank=world_rank, + world_size=world_size, + voxel_size=cfg.voxel_size, + ) + print("Model initialized. Number of GS:", len(self.splats["anchors"])) + + # Densification Strategy + self.cfg.strategy.check_sanity(self.splats, self.optimizers) + + self.strategy_state = self.cfg.strategy.initialize_state( + scene_scale=self.scene_scale + ) + # Compression Strategy + self.compression_method = None + if cfg.compression is not None: + if cfg.compression == "png": + self.compression_method = PngCompression() + else: + raise ValueError(f"Unknown compression strategy: {cfg.compression}") + + self.pose_optimizers = [] + if cfg.pose_opt: + self.pose_adjust = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_adjust.zero_init() + self.pose_optimizers = [ + torch.optim.Adam( + self.pose_adjust.parameters(), + lr=cfg.pose_opt_lr * math.sqrt(cfg.batch_size), + weight_decay=cfg.pose_opt_reg, + ) + ] + if world_size > 1: + self.pose_adjust = DDP(self.pose_adjust) + + if cfg.pose_noise > 0.0: + self.pose_perturb = CameraOptModule(len(self.trainset)).to(self.device) + self.pose_perturb.random_init(cfg.pose_noise) + if world_size > 1: + self.pose_perturb = DDP(self.pose_perturb) + + self.app_optimizers = [] + if cfg.app_opt: + assert feature_dim is not None + self.app_module = AppearanceOptModule( + len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + ).to(self.device) + # initialize the last layer to be zero so that the initial output is zero. 
+ torch.nn.init.zeros_(self.app_module.color_head[-1].weight) + torch.nn.init.zeros_(self.app_module.color_head[-1].bias) + self.app_optimizers = [ + torch.optim.Adam( + self.app_module.embeds.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, + weight_decay=cfg.app_opt_reg, + ), + torch.optim.Adam( + self.app_module.color_head.parameters(), + lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), + ), + ] + if world_size > 1: + self.app_module = DDP(self.app_module) + + self.bil_grid_optimizers = [] + if cfg.use_bilateral_grid: + self.bil_grids = BilateralGrid( + len(self.trainset), + grid_X=cfg.bilateral_grid_shape[0], + grid_Y=cfg.bilateral_grid_shape[1], + grid_W=cfg.bilateral_grid_shape[2], + ).to(self.device) + self.bil_grid_optimizers = [ + torch.optim.Adam( + self.bil_grids.parameters(), + lr=2e-3 * math.sqrt(cfg.batch_size), + eps=1e-15, + ), + ] + + # Losses & Metrics. + self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device) + self.psnr = PeakSignalNoiseRatio(data_range=1.0).to(self.device) + + if cfg.lpips_net == "alex": + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="alex", normalize=True + ).to(self.device) + elif cfg.lpips_net == "vgg": + # The 3DGS official repo uses lpips vgg, which is equivalent with the following: + self.lpips = LearnedPerceptualImagePatchSimilarity( + net_type="vgg", normalize=False + ).to(self.device) + else: + raise ValueError(f"Unknown LPIPS network: {cfg.lpips_net}") + + # Viewer + if not self.cfg.disable_viewer: + self.server = viser.ViserServer(port=cfg.port, verbose=False) + self.viewer = nerfview.Viewer( + server=self.server, + render_fn=self._viewer_render_fn, + mode="training", + ) + + def get_visibility_mask( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + packed: bool, + rasterize_mode: str, +): + anchors = self.splats["anchors"] # [N, 3] + # rasterization does normalization internally + quats = self.splats["quats"] # [N, 4] + scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3] + + visibility_mask = filter_visible_gaussians( + means=anchors, + quats=quats, + scales=scales, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=packed, + rasterize_mode=rasterize_mode, + ) + + return visibility_mask + + def get_neural_gaussians(self, cam_pos, selection=None): + + # If no visibility mask is provided, we select all anchors including their offsets + if selection is None: + selection = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device) + + selected_features = self.splats["features"][selection] # [M, c] + selected_anchors = self.splats["anchors"][selection] # [M, 3] + selected_offsets = self.splats["offsets"][selection] # [M, k, 3] + selected_scales = torch.exp(self.splats["scales"][selection]) # [M, 6] + + # See formula (5) in Scaffold-GS + view_dir = selected_anchors - cam_pos # [M, 3] + view_dir_normalized = view_dir / view_dir.norm(dim=1, keepdim=True) # [M, 3] + + # See formula (9) and the appendix for the rest + feature_view_dir = torch.cat([selected_features, view_dir_normalized], dim=1) # [M, c+3] + + k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor + + # Apply MLPs (they output per-offset features concatenated along the last dimension) + neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] + neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] + pos_opacity_mask = (neural_opacity > 0.0).view(-1) # [M*k] + + # Get color 
and reshape + neural_colors = self.cfg.strategy.colors_mlp(feature_view_dir) # [M, k*3] + neural_colors = neural_colors.view(-1, 3) # [M*k, 3] + + # Get scale and rotation and reshape + neural_scale_rot = self.cfg.strategy.scale_rot_mlp(feature_view_dir) # [M, k*7] + neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] + + # Reshape selected_offsets, scales, and anchors + selected_offsets = selected_offsets.view(-1, 3) # [M*k, 3] + scales_repeated = selected_scales.unsqueeze(1).repeat(1, k, 1).view(-1, 6) # [M*k, 6] + anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3] + + # Apply positive opacity mask + selected_opacity = neural_opacity[pos_opacity_mask].squeeze(-1) # [m] + selected_colors = neural_colors[pos_opacity_mask] # [m, 3] + selected_scale_rot = neural_scale_rot[pos_opacity_mask] # [m, 7] + selected_offsets = selected_offsets[pos_opacity_mask] # [m, 3] + scales_repeated = scales_repeated[pos_opacity_mask] # [m, 6] + anchors_repeated = anchors_repeated[pos_opacity_mask] # [m, 3] + + # Compute scales and rotations + scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3] + rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [m, 4] + + # Compute offsets and anchors + offsets = selected_offsets * scales_repeated[:, :3] # [m, 3] + anchors = anchors_repeated + offsets # [m, 3] + + return anchors, selected_colors, selected_opacity, scales, rotation, neural_opacity, pos_opacity_mask + + def rasterize_splats( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + **kwargs, + ) -> Tuple[Tensor, Tensor, Dict, Tensor]: + + visibility_mask = self.get_visibility_mask(camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + packed=self.cfg.packed, + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic", + ) + + anchors, color_mlp, opacities, scales, quats, neural_opacity, selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], selection=visibility_mask) + + image_ids = kwargs.pop("image_ids", None) + if self.cfg.app_opt: + colors = self.app_module( + features=self.splats["features"], + embed_ids=image_ids, + dirs=anchors[None, :, :] - camtoworlds[:, None, :3, 3], + sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + ) + colors = colors + color_mlp + colors = torch.sigmoid(colors) + else: + colors = color_mlp # [N, K, 3] + + rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" + render_colors, render_alphas, info = rasterization( + means=anchors, + quats=quats, + scales=scales, + opacities=opacities, + colors=colors, + viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] + Ks=Ks, # [C, 3, 3] + width=width, + height=height, + packed=self.cfg.packed, + absgrad=self.cfg.strategy.absgrad, + sparse_grad=self.cfg.sparse_grad, + rasterize_mode=rasterize_mode, + distributed=self.world_size > 1, + **kwargs, + ) + return render_colors, render_alphas, info, scales + + def train(self): + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + # Dump cfg. 
+ if world_rank == 0: + with open(f"{cfg.result_dir}/cfg.yml", "w") as f: + yaml.dump(vars(cfg), f) + + max_steps = cfg.max_steps + init_step = 0 + + schedulers = [ + # anchors has a learning rate schedule, that end at 0.01 of the initial value + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["anchors"], gamma=0.01 ** (1.0 / max_steps) + ), + ] + if cfg.pose_opt: + # pose optimization has a learning rate schedule + schedulers.append( + torch.optim.lr_scheduler.ExponentialLR( + self.pose_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ) + ) + if cfg.use_bilateral_grid: + # bilateral grid has a learning rate schedule. Linear warmup for 1000 steps. + schedulers.append( + torch.optim.lr_scheduler.ChainedScheduler( + [ + torch.optim.lr_scheduler.LinearLR( + self.bil_grid_optimizers[0], + start_factor=0.01, + total_iters=1000, + ), + torch.optim.lr_scheduler.ExponentialLR( + self.bil_grid_optimizers[0], gamma=0.01 ** (1.0 / max_steps) + ), + ] + ) + ) + + trainloader = torch.utils.data.DataLoader( + self.trainset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=4, + persistent_workers=True, + pin_memory=True, + ) + trainloader_iter = iter(trainloader) + + # Training loop. + global_tic = time.time() + pbar = tqdm.tqdm(range(init_step, max_steps)) + for step in pbar: + if not cfg.disable_viewer: + while self.viewer.state.status == "paused": + time.sleep(0.01) + self.viewer.lock.acquire() + tic = time.time() + + try: + data = next(trainloader_iter) + except StopIteration: + trainloader_iter = iter(trainloader) + data = next(trainloader_iter) + + camtoworlds = camtoworlds_gt = data["camtoworld"].to(device) # [1, 4, 4] + Ks = data["K"].to(device) # [1, 3, 3] + pixels = data["image"].to(device) / 255.0 # [1, H, W, 3] + num_train_rays_per_step = ( + pixels.shape[0] * pixels.shape[1] * pixels.shape[2] + ) + image_ids = data["image_id"].to(device) + if cfg.depth_loss: + points = data["points"].to(device) # [1, M, 2] + depths_gt = data["depths"].to(device) # [1, M] + + height, width = pixels.shape[1:3] + + if cfg.pose_noise: + camtoworlds = self.pose_perturb(camtoworlds, image_ids) + + if cfg.pose_opt: + camtoworlds = self.pose_adjust(camtoworlds, image_ids) + + # forward + renders, alphas, info, scales = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=None, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + image_ids=image_ids, + render_mode="RGB+ED" if cfg.depth_loss else "RGB", + ) + if renders.shape[-1] == 4: + colors, depths = renders[..., 0:3], renders[..., 3:4] + else: + colors, depths = renders, None + + if cfg.use_bilateral_grid: + grid_y, grid_x = torch.meshgrid( + (torch.arange(height, device=self.device) + 0.5) / height, + (torch.arange(width, device=self.device) + 0.5) / width, + indexing="ij", + ) + grid_xy = torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0) + colors = slice(self.bil_grids, grid_xy, colors, image_ids)["rgb"] + + if cfg.random_bkgd: + bkgd = torch.rand(1, 3, device=device) + colors = colors + bkgd * (1.0 - alphas) + + self.cfg.strategy.step_pre_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + ) + + # loss + l1loss = F.l1_loss(colors, pixels) + ssimloss = 1.0 - fused_ssim( + colors.permute(0, 3, 1, 2), pixels.permute(0, 3, 1, 2), padding="valid" + ) + loss = l1loss * (1.0 - cfg.ssim_lambda) + loss += ssimloss * cfg.ssim_lambda + loss += scales.prod(dim=1).mean() * cfg.scale_reg + if cfg.depth_loss: + # query depths from 
depth map + points = torch.stack( + [ + points[:, :, 0] / (width - 1) * 2 - 1, + points[:, :, 1] / (height - 1) * 2 - 1, + ], + dim=-1, + ) # normalize to [-1, 1] + grid = points.unsqueeze(2) # [1, M, 1, 2] + depths = F.grid_sample( + depths.permute(0, 3, 1, 2), grid, align_corners=True + ) # [1, 1, M, 1] + depths = depths.squeeze(3).squeeze(1) # [1, M] + # calculate loss in disparity space + disp = torch.where(depths > 0.0, 1.0 / depths, torch.zeros_like(depths)) + disp_gt = 1.0 / depths_gt # [1, M] + depthloss = F.l1_loss(disp, disp_gt) * self.scene_scale + loss += depthloss * cfg.depth_lambda + if cfg.use_bilateral_grid: + tvloss = 10 * total_variation_loss(self.bil_grids.grids) + loss += tvloss + + # regularizations + # not gonna work. Check this + # if cfg.opacity_reg > 0.0: + # loss = ( + # loss + # + cfg.opacity_reg + # * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() + # ) + + loss.backward() + + desc = f"loss={loss.item():.3f}| " + if cfg.depth_loss: + desc += f"depth loss={depthloss.item():.6f}| " + if cfg.pose_opt and cfg.pose_noise: + # monitor the pose error if we inject noise + pose_err = F.l1_loss(camtoworlds_gt, camtoworlds) + desc += f"pose err={pose_err.item():.6f}| " + pbar.set_description(desc) + + if world_rank == 0 and cfg.tb_every > 0 and step % cfg.tb_every == 0: + mem = torch.cuda.max_memory_allocated() / 1024**3 + self.writer.add_scalar("train/loss", loss.item(), step) + self.writer.add_scalar("train/l1loss", l1loss.item(), step) + self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) + self.writer.add_scalar("train/num_GS", len(self.splats["anchors"]), step) + self.writer.add_scalar("train/mem", mem, step) + if cfg.depth_loss: + self.writer.add_scalar("train/depthloss", depthloss.item(), step) + if cfg.use_bilateral_grid: + self.writer.add_scalar("train/tvloss", tvloss.item(), step) + if cfg.tb_save_image: + canvas = torch.cat([pixels, colors], dim=2).detach().cpu().numpy() + canvas = canvas.reshape(-1, *canvas.shape[2:]) + self.writer.add_image("train/render", canvas, step) + self.writer.flush() + + # save checkpoint before updating the model + if step in [i - 1 for i in cfg.save_steps] or step == max_steps - 1: + mem = torch.cuda.max_memory_allocated() / 1024**3 + stats = { + "mem": mem, + "ellipse_time": time.time() - global_tic, + "num_GS": len(self.splats["anchors"]), + } + print("Step: ", step, stats) + with open( + f"{self.stats_dir}/train_step{step:04d}_rank{self.world_rank}.json", + "w", + ) as f: + json.dump(stats, f) + data = {"step": step, "splats": self.splats.state_dict()} + if cfg.pose_opt: + if world_size > 1: + data["pose_adjust"] = self.pose_adjust.module.state_dict() + else: + data["pose_adjust"] = self.pose_adjust.state_dict() + if cfg.app_opt: + if world_size > 1: + data["app_module"] = self.app_module.module.state_dict() + else: + data["app_module"] = self.app_module.state_dict() + torch.save( + data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" + ) + + # For now no post steps + # self.cfg.strategy.step_post_backward( + # params=self.splats, + # optimizers=self.optimizers, + # state=self.strategy_state, + # step=step, + # info=info, + # packed=cfg.packed, + # ) + + # Turn Gradients into Sparse Tensor before running optimizer + if cfg.sparse_grad: + assert cfg.packed, "Sparse gradients only work with packed mode." 
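+                # In packed mode, info["gaussian_ids"] maps each projected splat
+                # back to the Gaussian it came from, so the dense gradients can be
+                # scattered into a COO sparse tensor that only stores rows for
+                # Gaussians that were actually visible this step.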
+ gaussian_ids = info["gaussian_ids"] + for k in self.splats.keys(): + grad = self.splats[k].grad + if grad is None or grad.is_sparse: + continue + self.splats[k].grad = torch.sparse_coo_tensor( + indices=gaussian_ids[None], # [1, nnz] + values=grad[gaussian_ids], # [nnz, ...] + size=self.splats[k].size(), # [N, ...] + is_coalesced=len(Ks) == 1, + ) + + # optimize + for optimizer in self.optimizers.values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.pose_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.app_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for optimizer in self.bil_grid_optimizers: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + for scheduler in schedulers: + scheduler.step() + + # eval the full set + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_traj(step) + + # run compression + if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: + self.run_compression(step=step) + + if not cfg.disable_viewer: + self.viewer.lock.release() + num_train_steps_per_sec = 1.0 / (time.time() - tic) + num_train_rays_per_sec = ( + num_train_rays_per_step * num_train_steps_per_sec + ) + # Update the viewer state. + self.viewer.state.num_train_rays_per_sec = num_train_rays_per_sec + # Update the scene. + self.viewer.update(step, num_train_rays_per_step) + + @torch.no_grad() + def eval(self, step: int, stage: str = "val"): + """Entry for evaluation.""" + print("Running evaluation...") + cfg = self.cfg + device = self.device + world_rank = self.world_rank + world_size = self.world_size + + valloader = torch.utils.data.DataLoader( + self.valset, batch_size=1, shuffle=False, num_workers=1 + ) + ellipse_time = 0 + metrics = defaultdict(list) + for i, data in enumerate(valloader): + camtoworlds = data["camtoworld"].to(device) + Ks = data["K"].to(device) + pixels = data["image"].to(device) / 255.0 + height, width = pixels.shape[1:3] + + torch.cuda.synchronize() + tic = time.time() + colors, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + ) # [1, H, W, 3] + torch.cuda.synchronize() + ellipse_time += time.time() - tic + + colors = torch.clamp(colors, 0.0, 1.0) + canvas_list = [pixels, colors] + + if world_rank == 0: + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + imageio.imwrite( + f"{self.render_dir}/{stage}_step{step}_{i:04d}.png", + canvas, + ) + + pixels_p = pixels.permute(0, 3, 1, 2) # [1, 3, H, W] + colors_p = colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["psnr"].append(self.psnr(colors_p, pixels_p)) + metrics["ssim"].append(self.ssim(colors_p, pixels_p)) + metrics["lpips"].append(self.lpips(colors_p, pixels_p)) + if cfg.use_bilateral_grid: + cc_colors = color_correct(colors, pixels) + cc_colors_p = cc_colors.permute(0, 3, 1, 2) # [1, 3, H, W] + metrics["cc_psnr"].append(self.psnr(cc_colors_p, pixels_p)) + + if world_rank == 0: + ellipse_time /= len(valloader) + + stats = {k: torch.stack(v).mean().item() for k, v in metrics.items()} + stats.update( + { + "ellipse_time": ellipse_time, + "num_GS": len(self.splats["anchors"]), + } + ) + print( + f"PSNR: {stats['psnr']:.3f}, SSIM: {stats['ssim']:.4f}, LPIPS: {stats['lpips']:.3f} " + f"Time: {stats['ellipse_time']:.3f}s/image " + f"Number of GS: {stats['num_GS']}" 
+ ) + # save stats as json + with open(f"{self.stats_dir}/{stage}_step{step:04d}.json", "w") as f: + json.dump(stats, f) + # save stats to tensorboard + for k, v in stats.items(): + self.writer.add_scalar(f"{stage}/{k}", v, step) + self.writer.flush() + + @torch.no_grad() + def render_traj(self, step: int): + """Entry for trajectory rendering.""" + print("Running trajectory rendering...") + cfg = self.cfg + device = self.device + + camtoworlds_all = self.parser.camtoworlds[5:-5] + if cfg.render_traj_path == "interp": + camtoworlds_all = generate_interpolated_path( + camtoworlds_all, 1 + ) # [N, 3, 4] + elif cfg.render_traj_path == "ellipse": + height = camtoworlds_all[:, 2, 3].mean() + camtoworlds_all = generate_ellipse_path_z( + camtoworlds_all, height=height + ) # [N, 3, 4] + elif cfg.render_traj_path == "spiral": + camtoworlds_all = generate_spiral_path( + camtoworlds_all, + bounds=self.parser.bounds * self.scene_scale, + spiral_scale_r=self.parser.extconf["spiral_radius_scale"], + ) + else: + raise ValueError( + f"Render trajectory type not supported: {cfg.render_traj_path}" + ) + + camtoworlds_all = np.concatenate( + [ + camtoworlds_all, + np.repeat( + np.array([[[0.0, 0.0, 0.0, 1.0]]]), len(camtoworlds_all), axis=0 + ), + ], + axis=1, + ) # [N, 4, 4] + + camtoworlds_all = torch.from_numpy(camtoworlds_all).float().to(device) + K = torch.from_numpy(list(self.parser.Ks_dict.values())[0]).float().to(device) + width, height = list(self.parser.imsize_dict.values())[0] + + canvas_all = [] + for i in tqdm.trange(len(camtoworlds_all), desc="Rendering trajectory"): + camtoworlds = camtoworlds_all[i : i + 1] + Ks = K[None] + + renders, _, _, _ = self.rasterize_splats( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + sh_degree=cfg.sh_degree, + near_plane=cfg.near_plane, + far_plane=cfg.far_plane, + render_mode="RGB+ED", + ) # [1, H, W, 4] + colors = torch.clamp(renders[..., 0:3], 0.0, 1.0) # [1, H, W, 3] + depths = renders[..., 3:4] # [1, H, W, 1] + depths = (depths - depths.min()) / (depths.max() - depths.min()) + canvas_list = [colors, depths.repeat(1, 1, 1, 3)] + + # write images + canvas = torch.cat(canvas_list, dim=2).squeeze(0).cpu().numpy() + canvas = (canvas * 255).astype(np.uint8) + canvas_all.append(canvas) + + # save to video + video_dir = f"{cfg.result_dir}/videos" + os.makedirs(video_dir, exist_ok=True) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + for canvas in canvas_all: + writer.append_data(canvas) + writer.close() + print(f"Video saved to {video_dir}/traj_{step}.mp4") + + @torch.no_grad() + def run_compression(self, step: int): + """Entry for running compression.""" + print("Running compression...") + world_rank = self.world_rank + + compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" + os.makedirs(compress_dir, exist_ok=True) + + self.compression_method.compress(compress_dir, self.splats) + + # evaluate compression + splats_c = self.compression_method.decompress(compress_dir) + for k in splats_c.keys(): + self.splats[k].data = splats_c[k].to(self.device) + self.eval(step=step, stage="compress") + + @torch.no_grad() + def _viewer_render_fn( + self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] + ): + """Callable function for the viewer.""" + W, H = img_wh + c2w = camera_state.c2w + K = camera_state.get_K(img_wh) + c2w = torch.from_numpy(c2w).float().to(self.device) + K = torch.from_numpy(K).float().to(self.device) + + render_colors, _, _, _ = self.rasterize_splats( + camtoworlds=c2w[None], + Ks=K[None], + 
width=W, + height=H, + sh_degree=None, # active all SH degrees + radius_clip=3.0, # skip GSs that have small image radius (in pixels) + ) # [1, H, W, 3] + return render_colors[0].cpu().numpy() + + +def main(local_rank: int, world_rank, world_size: int, cfg: Config): + if world_size > 1 and not cfg.disable_viewer: + cfg.disable_viewer = True + if world_rank == 0: + print("Viewer is disabled in distributed training.") + + runner = Runner(local_rank, world_rank, world_size, cfg) + + if cfg.ckpt is not None: + # run eval only + ckpts = [ + torch.load(file, map_location=runner.device, weights_only=True) + for file in cfg.ckpt + ] + for k in runner.splats.keys(): + runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + step = ckpts[0]["step"] + runner.eval(step=step) + runner.render_traj(step=step) + if cfg.compression is not None: + runner.run_compression(step=step) + else: + runner.train() + + if not cfg.disable_viewer: + print("Viewer running... Ctrl+C to exit.") + time.sleep(1000000) + + +if __name__ == "__main__": + """ + Usage: + + ```bash + # Single GPU training + CUDA_VISIBLE_DEVICES=0 python simple_trainer.py default + + # Distributed training on 4 GPUs: Effectively 4x batch size so run 4x less steps. + CUDA_VISIBLE_DEVICES=0,1,2,3 python simple_trainer.py default --steps_scaler 0.25 + + """ + + cfg = tyro.cli(Config) + cfg.adjust_steps(cfg.steps_scaler) + + # try import extra dependencies + if cfg.compression == "png": + try: + import plas + import torchpq + except: + raise ImportError( + "To use PNG compression, you need to install " + "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " + "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " + ) + + cli(main, cfg, verbose=True) diff --git a/gsplat/rendering.py b/gsplat/rendering.py index cadacaa2f..174f8121d 100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -582,6 +582,218 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso return render_colors, render_alphas, meta +def filter_visible_gaussians( + means: Tensor, # [N, 3] + quats: Tensor, # [N, 4] + scales: Tensor, # [N, 3] + viewmats: Tensor, # [C, 4, 4] + Ks: Tensor, # [C, 3, 3] + width: int, + height: int, + near_plane: float = 0.01, + far_plane: float = 1e10, + radius_clip: float = 0.0, + eps2d: float = 0.3, + sh_degree: Optional[int] = None, + packed: bool = True, + rasterize_mode: Literal["classic", "antialiased"] = "classic", + ortho: bool = False, + covars: Optional[Tensor] = None, +) -> Tuple[Tensor, Tensor, Dict]: + """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). + + This function provides a handful features for 3D Gaussian rasterization, which + we detail in the following notes. A complete profiling of the these features + can be found in the :ref:`profiling` page. + + .. note:: + **Multi-GPU Distributed Rasterization**: This function can be used in a multi-GPU + distributed scenario by setting `distributed` to True. When `distributed` is True, + a subset of total Gaussians could be passed into this function in each rank, and + the function will collaboratively render a set of images using Gaussians from all ranks. Note + to achieve balanced computation, it is recommended (not enforced) to have similar number of + Gaussians in each rank. But we do enforce that the number of cameras to be rendered + in each rank is the same. 
The function will return the rendered images
+    corresponding to the input cameras in each rank, and allows gradients to flow back to the
+    Gaussians living in other ranks. For the details, please refer to the paper
+    `On Scaling Up 3D Gaussian Splatting Training `_.
+
+    .. note::
+        **Batch Rasterization**: This function allows for rasterizing a set of 3D Gaussians
+        to a batch of images in one go, by simply providing the batched `viewmats` and `Ks`.
+
+    .. note::
+        **Support N-D Features**: If `sh_degree` is None,
+        the `colors` is expected to be with shape [N, D] or [C, N, D], in which D is the channel of
+        the features to be rendered. The computation is slow when D > 32 at the moment.
+        If `sh_degree` is set, the `colors` is expected to be the SH coefficients with
+        shape [N, K, 3] or [C, N, K, 3], where K is the number of SH bases. In this case, it is expected
+        that :math:`(\\textit{sh_degree} + 1) ^ 2 \\leq K`, where `sh_degree` controls the
+        activated bases in the SH coefficients.
+
+    .. note::
+        **Depth Rendering**: This function supports colors and/or depths via `render_mode`.
+        The supported modes are "RGB", "D", "ED", "RGB+D", and "RGB+ED". "RGB" renders the
+        colored image that respects the `colors` argument. "D" renders the accumulated z-depth
+        :math:`\\sum_i w_i z_i`. "ED" renders the expected z-depth
+        :math:`\\frac{\\sum_i w_i z_i}{\\sum_i w_i}`. "RGB+D" and "RGB+ED" render both
+        the colored image and the depth, in which the depth is the last channel of the output.
+
+    .. note::
+        **Memory-Speed Trade-off**: The `packed` argument provides a trade-off between
+        memory footprint and runtime. If `packed` is True, the intermediate results are
+        packed into sparse tensors, which is more memory efficient but might be slightly
+        slower. This is especially helpful when the scene is large and each camera sees only
+        a small portion of the scene. If `packed` is False, the intermediate results are
+        with shape [C, N, ...], which is faster but might consume more memory.
+
+    .. note::
+        **Sparse Gradients**: If `sparse_grad` is True, the gradients for {means, quats, scales}
+        will be stored in a `COO sparse layout `_.
+        This can be helpful for saving memory
+        for training when the scene is large and each iteration only activates a small portion
+        of the Gaussians. Usually a sparse optimizer is required to work with sparse gradients,
+        such as `torch.optim.SparseAdam `_.
+        This argument is only effective when `packed` is True.
+
+    .. note::
+        **Speed-up for Large Scenes**: The `radius_clip` argument is extremely helpful for
+        speeding up large scale scenes or scenes with large depth of field. Gaussians with a
+        2D radius smaller than or equal to this value (in pixels) will be skipped during rasterization.
+        This will skip all the far-away Gaussians that are too small to be seen in the image.
+        But be warned that if there are close-up Gaussians that are also below this threshold, they will
+        also get skipped (which rarely happens in practice). This is disabled by default by setting
+        `radius_clip` to 0.0.
+
+    .. note::
+        **Antialiased Rendering**: If `rasterize_mode` is "antialiased", the function will
+        apply a view-dependent compensation factor
+        :math:`\\rho=\\sqrt{\\frac{Det(\\Sigma)}{Det(\\Sigma+ \\epsilon I)}}` to Gaussian
+        opacities, where :math:`\\Sigma` is the projected 2D covariance matrix and :math:`\\epsilon`
+        is the `eps2d`. This will make the rendered image more antialiased, as proposed in
+        the paper `Mip-Splatting: Alias-free 3D Gaussian Splatting `_.
+
+    .. note::
+        **AbsGrad**: If `absgrad` is True, the absolute gradients of the projected
+        2D means will be computed during the backward pass, which could be accessed by
+        `meta["means2d"].absgrad`. This is an implementation of the paper
+        `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_,
+        which is shown to be more effective for splitting Gaussians during training.
+
+    .. warning::
+        This function is currently not differentiable w.r.t. the camera intrinsics `Ks`.
+
+    Args:
+        means: The 3D centers of the Gaussians. [N, 3]
+        quats: The quaternions of the Gaussians (wxyz convention). They do not need to be normalized. [N, 4]
+        scales: The scales of the Gaussians. [N, 3]
+        viewmats: The world-to-cam transformation of the cameras. [C, 4, 4]
+        Ks: The camera intrinsics. [C, 3, 3]
+        width: The width of the image.
+        height: The height of the image.
+        near_plane: The near plane for clipping. Default is 0.01.
+        far_plane: The far plane for clipping. Default is 1e10.
+        radius_clip: Gaussians with a 2D radius smaller than or equal to this value will be
+            skipped. This is extremely helpful for speeding up large scale scenes.
+            Default is 0.0.
+        eps2d: An epsilon added to the eigenvalues of projected 2D covariance matrices.
+            This prevents the projected GSs from becoming too small. For example, eps2d=0.3
+            leads to a minimum size of 3 pixel units. Default is 0.3.
+        sh_degree: The SH degree to use, which can be smaller than the total
+            number of bands. If set, the `colors` should be [(C,) N, K, 3] SH coefficients,
+            else the `colors` should be [(C,) N, D] post-activation color values. Default is None.
+        packed: Whether to use packed mode, which is more memory efficient but might or
+            might not be as fast. Default is True.
+        tile_size: The size of the tiles for rasterization. Default is 16.
+            (Note: other values are not tested)
+        backgrounds: The background colors. [C, D]. Default is None.
+        render_mode: The rendering mode. Supported modes are "RGB", "D", "ED", "RGB+D",
+            and "RGB+ED". "RGB" renders the colored image, "D" renders the accumulated depth, and
+            "ED" renders the expected depth. Default is "RGB".
+        sparse_grad: If true, the gradients for {means, quats, scales} will be stored in
+            a COO sparse layout. This can be helpful for saving memory. Default is False.
+        absgrad: If true, the absolute gradients of the projected 2D means
+            will be computed during the backward pass, which could be accessed by
+            `meta["means2d"].absgrad`. Default is False.
+        rasterize_mode: The rasterization mode. Supported modes are "classic" and
+            "antialiased". Default is "classic".
+        channel_chunk: The number of channels to render in one go. Default is 32.
+            If the required rendering channels are larger than this value, the rendering
+            will be done in a loop over chunks.
+        distributed: Whether to use distributed rendering. Default is False. If True,
+            the input Gaussians are expected to be a subset of the scene in each rank, and
+            the function will collaboratively render the images for all ranks.
+        ortho: Whether to use orthographic projection. In that case fx and fy become the scaling
+            factors to convert projected coordinates into pixel space and cx, cy become offsets.
+        covars: Optional covariance matrices of the Gaussians. If provided, the `quats` and
+            `scales` will be ignored. [N, 3, 3]. Default is None.
+
+    Returns:
+        A tuple:
+
+        **render_colors**: The rendered colors. [C, height, width, X].
+        X depends on the `render_mode` and input `colors`.
If `render_mode` is "RGB", + X is D; if `render_mode` is "D" or "ED", X is 1; if `render_mode` is "RGB+D" or + "RGB+ED", X is D+1. + + **render_alphas**: The rendered alphas. [C, height, width, 1]. + + **meta**: A dictionary of intermediate results of the rasterization. + + """ + N = means.shape[0] + C = viewmats.shape[0] + assert means.shape == (N, 3), means.shape + if covars is None: + assert quats.shape == (N, 4), quats.shape + assert scales.shape == (N, 3), scales.shape + else: + assert covars.shape == (N, 3, 3), covars.shape + quats, scales = None, None + # convert covars from 3x3 matrix to upper-triangular 6D vector + tri_indices = ([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]) + covars = covars[..., tri_indices[0], tri_indices[1]] + assert viewmats.shape == (C, 4, 4), viewmats.shape + assert Ks.shape == (C, 3, 3), Ks.shape + + # Project Gaussians to 2D. Directly pass in {quats, scales} is faster than precomputing covars. + proj_results = fully_fused_projection( + means, + covars, + quats, + scales, + viewmats, + Ks, + width, + height, + eps2d=eps2d, + packed=packed, + near_plane=near_plane, + far_plane=far_plane, + radius_clip=radius_clip, + sparse_grad=False, + calc_compensations=(rasterize_mode == "antialiased"), + ortho=ortho, + ) + + if packed: + # The results are packed into shape [nnz, ...]. All elements are valid. + ( + _, + _, + radii, + _, + _, + _, + _, + ) = proj_results + else: + # The results are with shape [C, N, ...]. Only the elements with radii > 0 are valid. + radii, _, _, _, _ = proj_results + + return radii.squeeze(0) > 0 + + def _rasterization( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] @@ -1447,7 +1659,7 @@ def rasterization_2dgs_inria_wrapper( render_depth_expected * (1 - depth_ratio) + (depth_ratio) * render_depth_median ) - normals_surf = depth_to_normal(render_depth, torch.linalg.inv(viewmats), Ks) + normals_surf = depth_to_normal(render_depth, viewmats, Ks) normals_surf = normals_surf * (render_alphas).detach() render_colors = torch.cat([render_colors, render_depth], dim=-1) diff --git a/gsplat/strategy/__init__.py b/gsplat/strategy/__init__.py index 305dc8129..08ac72b8f 100644 --- a/gsplat/strategy/__init__.py +++ b/gsplat/strategy/__init__.py @@ -1,3 +1,4 @@ from .base import Strategy from .default import DefaultStrategy from .mcmc import MCMCStrategy +from .scaffold import ScaffoldStrategy diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py new file mode 100644 index 000000000..b62c777ff --- /dev/null +++ b/gsplat/strategy/scaffold.py @@ -0,0 +1,309 @@ +from dataclasses import dataclass +from typing import Any, Dict, Tuple, Union + +import torch + +from .base import Strategy +from .ops import duplicate, remove + + +@dataclass +class ScaffoldStrategy(Strategy): + """A neural gaussian strategy that follows the paper: + + `Scaffold-GS: Structured 3D Gaussians for View-Adaptive Rendering `_ + + The strategy will: + + - Periodically duplicate GSs with high image plane gradients and small scales. + - Periodically split GSs with high image plane gradients and large scales. + - Periodically prune GSs with low opacity. + - Periodically reset GSs to a lower opacity. + + If `absgrad=True`, it will use the absolute gradients instead of average gradients + for GS duplicating & splitting, following the AbsGS paper: + + `AbsGS: Recovering Fine Details for 3D Gaussian Splatting `_ + + Which typically leads to better results but requires to set the `grow_grad2d` to a + higher value, e.g., 0.0008. 
Also, the :func:`rasterization` function should be called
+    with `absgrad=True` as well so that the absolute gradients are computed.
+
+    Args:
+        prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005.
+        grow_grad2d (float): GSs with image plane gradient above this value will be
+            split/duplicated. Default is 0.0002.
+        grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this
+            value will be duplicated. Above will be split. Default is 0.01.
+        grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above
+            this value will be split. Default is 0.05.
+        prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this
+            value will be pruned. Default is 0.1.
+        prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above
+            this value will be pruned. Default is 0.15.
+        refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this
+            iteration. Default is 0. Set to a positive value to enable this feature.
+        refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
+        refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
+        refine_every (int): Refine GSs every this many steps. Default is 100.
+        pause_refine_after_reset (int): Pause refining GSs until this number of steps after
+            reset. Default is 0 (no pause at all); one might want to set this to the
+            number of images in the training set.
+        absgrad (bool): Use absolute gradients for GS splitting. Default is False.
+        revised_opacity (bool): Whether to use the revised opacity heuristic from
+            arXiv:2404.06109 (experimental). Default is False.
+        verbose (bool): Whether to print verbose information. Default is False.
+
+    Examples:
+
+        >>> from gsplat import ScaffoldStrategy, rasterization
+        >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
+        >>> optimizers: Dict[str, torch.optim.Optimizer] = ...
+        >>> strategy = ScaffoldStrategy()
+        >>> strategy.check_sanity(params, optimizers)
+        >>> strategy_state = strategy.initialize_state()
+        >>> for step in range(1000):
+        ...     render_image, render_alpha, info = rasterization(...)
+        ...     strategy.step_pre_backward(params, optimizers, strategy_state, step, info)
+        ...     loss = ...
+        ...     loss.backward()
+        ...
strategy.step_post_backward(params, optimizers, strategy_state, step, info) + + """ + + prune_opa: float = 0.005 + grow_grad2d: float = 0.0002 + grow_scale3d: float = 0.01 + grow_scale2d: float = 0.05 + prune_scale3d: float = 0.1 + prune_scale2d: float = 0.15 + refine_scale2d_stop_iter: int = 0 + refine_start_iter: int = 500 + refine_stop_iter: int = 15_000 + mean_feat_dim: int = 32 + n_feat_offsets: int = 10 + refine_every: int = 100 + pause_refine_after_reset: int = 0 + absgrad: bool = False + revised_opacity: bool = False + verbose: bool = False + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, 3 * n_feat_offsets), + torch.nn.Sigmoid() + ).cuda() + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, n_feat_offsets), + torch.nn.Tanh() + ).cuda() + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(mean_feat_dim, 7 * n_feat_offsets) + ).cuda() + + def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: + """Initialize and return the running state for this strategy. + + The returned state should be passed to the `step_pre_backward()` and + `step_post_backward()` functions. + """ + # Postpone the initialization of the state to the first step so that we can + # put them on the correct device. + # - grad2d: running accum of the norm of the image plane gradients for each GS. + # - count: running accum of how many time each GS is visible. + # - radii: the radii of the GSs (normalized by the image resolution). + state = {"grad2d": None, + "count": None, + "scene_scale": scene_scale} + if self.refine_scale2d_stop_iter > 0: + state["radii"] = None + return state + + def check_sanity( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + ): + """Sanity check for the parameters and optimizers. + + Check if: + * `params` and `optimizers` have the same keys. + * Each optimizer has exactly one param_group, corresponding to each parameter. + * The following keys are present: {"anchors", "scales", "quats", "opacities"}. + + Raises: + AssertionError: If any of the above conditions is not met. + + .. note:: + It is not required but highly recommended for the user to call this function + after initializing the strategy to ensure the convention of the parameters + and optimizers is as expected. + """ + + super().check_sanity(params, optimizers) + # The following keys are required for this strategy. + expected_params = ["anchors", + "features", + "offsets", + "scales", + "quats", + "opacities_mlp", + "colors_mlp", + "scale_rot_mlp"] + + assert len(expected_params) == len(params), "expected params and actual params don't match" + for key in expected_params: + assert key in params, f"{key} is required in params but missing." + + def step_pre_backward( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + info: Dict[str, Any], + ): + """Callback function to be executed before the `loss.backward()` call.""" + assert ( + "means2d" in info + ), "The 2D anchors of the Gaussians is required but missing." 
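+        # "means2d" is a non-leaf tensor, so autograd would normally discard its
+        # gradient; retain_grad() keeps it so the densification statistics can
+        # read the screen-space gradients after loss.backward().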
+        info["means2d"].retain_grad()
+
+    def step_post_backward(
+        self,
+        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+        optimizers: Dict[str, torch.optim.Optimizer],
+        state: Dict[str, Any],
+        step: int,
+        info: Dict[str, Any],
+        packed: bool = False,
+    ):
+        """Callback function to be executed after the `loss.backward()` call."""
+        if step >= self.refine_stop_iter:
+            return
+
+        self._update_state(params, state, info, packed=packed)
+
+        if (
+            step > self.refine_start_iter
+            and step % self.refine_every == 0
+        ):
+            # grow GSs
+            n_dupli, n_split = self._grow_anchors(params, optimizers, state, step)
+            if self.verbose:
+                print(
+                    f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. "
+                    f"Now having {len(params['anchors'])} GSs."
+                )
+
+            # prune GSs
+            n_prune = self._prune_gs(params, optimizers, state, step)
+            if self.verbose:
+                print(
+                    f"Step {step}: {n_prune} GSs pruned. "
+                    f"Now having {len(params['anchors'])} GSs."
+                )
+
+            # reset running stats
+            state["grad2d"].zero_()
+            state["count"].zero_()
+
+            if self.refine_scale2d_stop_iter > 0:
+                state["radii"].zero_()
+            torch.cuda.empty_cache()
+
+    def _update_state(
+        self,
+        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+        state: Dict[str, Any],
+        info: Dict[str, Any],
+        packed: bool = False,
+    ):
+        # "means2d" is the key that `step_pre_backward()` retained gradients for.
+        for key in ["means2d", "width", "height", "n_cameras", "radii", "gaussian_ids"]:
+            assert key in info, f"{key} is required but missing."
+
+        # normalize grads to [-1, 1] screen space
+        if self.absgrad:
+            grads = info["means2d"].absgrad.clone()
+        else:
+            grads = info["means2d"].grad.clone()
+        grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"]
+        grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"]
+
+        # initialize state on the first run
+        n_gaussian = params["anchors"].shape[0]
+        if state["grad2d"] is None:
+            state["grad2d"] = torch.zeros(n_gaussian, device=grads.device)
+        if state["count"] is None:
+            state["count"] = torch.zeros(n_gaussian, device=grads.device)
+        if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
+            assert "radii" in info, "radii is required but missing."
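+            # Track the largest observed 2D radius per anchor, normalized by the
+            # image resolution (updated via torch.maximum below).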
+ state["radii"] = torch.zeros(n_gaussian, device=grads.device) + + # update the running state + if packed: + # grads is [nnz, 2] + gs_ids = info["gaussian_ids"] # [nnz] + radii = info["radii"] # [nnz] + else: + # grads is [C, N, 2] + sel = info["radii"] > 0.0 # [C, N] + gs_ids = torch.where(sel)[1] # [nnz] + grads = grads[sel] # [nnz, 2] + radii = info["radii"][sel] # [nnz] + + state["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) + state["count"].index_add_( + 0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32) + ) + if self.refine_scale2d_stop_iter > 0: + # Should be ideally using scatter max + state["radii"][gs_ids] = torch.maximum( + state["radii"][gs_ids], + # normalize radii to [0, 1] screen space + radii / float(max(info["width"], info["height"])), + ) + + @torch.no_grad() + def _grow_anchors( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + ) -> Tuple[int, int]: + pass + + @torch.no_grad() + def _prune_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Any], + step: int, + ) -> int: + is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa + if step > self.reset_every: + is_too_big = ( + torch.exp(params["scales"]).max(dim=-1).values + > self.prune_scale3d * state["scene_scale"] + ) + # The official code also implements sreen-size pruning but + # it's actually not being used due to a bug: + # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 + # We implement it here for completeness but set `refine_scale2d_stop_iter` + # to 0 by default to disable it. + if step < self.refine_scale2d_stop_iter: + is_too_big |= state["radii"] > self.prune_scale2d + + is_prune = is_prune | is_too_big + + n_prune = is_prune.sum().item() + if n_prune > 0: + remove(params=params, optimizers=optimizers, state=state, mask=is_prune) + + return n_prune From f8fbb8fad3f77f15cb1c4c837a81563a268643f1 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 18 Sep 2024 23:28:42 +0200 Subject: [PATCH 02/29] implement anchor growing algorithm --- examples/simple_trainer_scaffold.py | 83 +++++++------ gsplat/rendering.py | 2 +- gsplat/strategy/ops.py | 91 +++++++++++++++ gsplat/strategy/scaffold.py | 173 +++++++++++++++++++++++----- 4 files changed, 280 insertions(+), 69 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 300842750..7afc740ff 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -38,7 +38,7 @@ from gsplat.compression import PngCompression from gsplat.distributed import cli -from gsplat.rendering import rasterization, filter_visible_gaussians +from gsplat.rendering import rasterization, view_to_visible_anchors from gsplat.strategy import ScaffoldStrategy @@ -162,6 +162,7 @@ def adjust_steps(self, factor: float): strategy = self.strategy strategy.refine_start_iter = int(strategy.refine_start_iter * factor) strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) + strategy.voxel_size = self.voxel_size # strategy.reset_every = int(strategy.reset_every * factor) # strategy.refine_every = int(strategy.refine_every * factor) @@ -199,7 +200,7 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] - features = torch.zeros((N, strategy.mean_feat_dim)) + features = torch.zeros((N, strategy.feat_dim)) offsets = 
torch.zeros((N, strategy.n_feat_offsets, 3)) params = [ @@ -334,7 +335,7 @@ def __init__( if cfg.app_opt: assert feature_dim is not None self.app_module = AppearanceOptModule( - len(self.trainset), feature_dim, cfg.app_embed_dim, cfg.sh_degree + len(self.trainset), feature_dim, cfg.app_embed_dim, None ).to(self.device) # initialize the last layer to be zero so that the initial output is zero. torch.nn.init.zeros_(self.app_module.color_head[-1].weight) @@ -394,7 +395,7 @@ def __init__( mode="training", ) - def get_visibility_mask( + def get_visible_anchor_mask( self, camtoworlds: Tensor, Ks: Tensor, @@ -408,7 +409,7 @@ def get_visibility_mask( quats = self.splats["quats"] # [N, 4] scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3] - visibility_mask = filter_visible_gaussians( + visible_anchor_mask = view_to_visible_anchors( means=anchors, quats=quats, scales=scales, @@ -420,18 +421,18 @@ def get_visibility_mask( rasterize_mode=rasterize_mode, ) - return visibility_mask + return visible_anchor_mask - def get_neural_gaussians(self, cam_pos, selection=None): + def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None): # If no visibility mask is provided, we select all anchors including their offsets - if selection is None: - selection = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device) + if visible_anchor_mask is None: + visible_anchor_mask = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device) - selected_features = self.splats["features"][selection] # [M, c] - selected_anchors = self.splats["anchors"][selection] # [M, 3] - selected_offsets = self.splats["offsets"][selection] # [M, k, 3] - selected_scales = torch.exp(self.splats["scales"][selection]) # [M, 6] + selected_features = self.splats["features"][visible_anchor_mask] # [M, c] + selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3] + selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3] + selected_scales = torch.exp(self.splats["scales"][visible_anchor_mask]) # [M, 6] # See formula (5) in Scaffold-GS view_dir = selected_anchors - cam_pos # [M, 3] @@ -445,7 +446,7 @@ def get_neural_gaussians(self, cam_pos, selection=None): # Apply MLPs (they output per-offset features concatenated along the last dimension) neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] - pos_opacity_mask = (neural_opacity > 0.0).view(-1) # [M*k] + neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k] # Get color and reshape neural_colors = self.cfg.strategy.colors_mlp(feature_view_dir) # [M, k*3] @@ -461,12 +462,12 @@ def get_neural_gaussians(self, cam_pos, selection=None): anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3] # Apply positive opacity mask - selected_opacity = neural_opacity[pos_opacity_mask].squeeze(-1) # [m] - selected_colors = neural_colors[pos_opacity_mask] # [m, 3] - selected_scale_rot = neural_scale_rot[pos_opacity_mask] # [m, 7] - selected_offsets = selected_offsets[pos_opacity_mask] # [m, 3] - scales_repeated = scales_repeated[pos_opacity_mask] # [m, 6] - anchors_repeated = anchors_repeated[pos_opacity_mask] # [m, 3] + selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [m] + selected_colors = neural_colors[neural_selection_mask] # [m, 3] + selected_scale_rot = neural_scale_rot[neural_selection_mask] # [m, 7] + selected_offsets = selected_offsets[neural_selection_mask] # [m, 3] + 
scales_repeated = scales_repeated[neural_selection_mask] # [m, 6] + anchors_repeated = anchors_repeated[neural_selection_mask] # [m, 3] # Compute scales and rotations scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3] @@ -474,9 +475,9 @@ def get_neural_gaussians(self, cam_pos, selection=None): # Compute offsets and anchors offsets = selected_offsets * scales_repeated[:, :3] # [m, 3] - anchors = anchors_repeated + offsets # [m, 3] + means = anchors_repeated + offsets # [m, 3] - return anchors, selected_colors, selected_opacity, scales, rotation, neural_opacity, pos_opacity_mask + return means, selected_colors, selected_opacity, scales, rotation, neural_opacity, neural_selection_mask def rasterize_splats( self, @@ -487,7 +488,8 @@ def rasterize_splats( **kwargs, ) -> Tuple[Tensor, Tensor, Dict, Tensor]: - visibility_mask = self.get_visibility_mask(camtoworlds=camtoworlds, + # We select only the visible anchors for faster inference + visible_anchor_mask = self.get_visible_anchor_mask(camtoworlds=camtoworlds, Ks=Ks, width=width, height=height, @@ -495,15 +497,15 @@ def rasterize_splats( rasterize_mode = "antialiased" if self.cfg.antialiased else "classic", ) - anchors, color_mlp, opacities, scales, quats, neural_opacity, selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], selection=visibility_mask) + # Get all the gaussians per voxel spawned from the anchors + means, color_mlp, opacities, scales, quats, neural_opacity, neural_selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], visible_anchor_mask=visible_anchor_mask) image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( features=self.splats["features"], embed_ids=image_ids, - dirs=anchors[None, :, :] - camtoworlds[:, None, :3, 3], - sh_degree=kwargs.pop("sh_degree", self.cfg.sh_degree), + dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], ) colors = colors + color_mlp colors = torch.sigmoid(colors) @@ -512,7 +514,7 @@ def rasterize_splats( rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, info = rasterization( - means=anchors, + means=means, quats=quats, scales=scales, opacities=opacities, @@ -528,6 +530,13 @@ def rasterize_splats( distributed=self.world_size > 1, **kwargs, ) + info.update( + { + "visible_anchor_mask": visible_anchor_mask, + "neural_selection_mask": neural_selection_mask, + "neural_opacities": neural_opacity, + } + ) return render_colors, render_alphas, info, scales def train(self): @@ -755,14 +764,14 @@ def train(self): ) # For now no post steps - # self.cfg.strategy.step_post_backward( - # params=self.splats, - # optimizers=self.optimizers, - # state=self.strategy_state, - # step=step, - # info=info, - # packed=cfg.packed, - # ) + self.cfg.strategy.step_post_backward( + params=self.splats, + optimizers=self.optimizers, + state=self.strategy_state, + step=step, + info=info, + packed=cfg.packed, + ) # Turn Gradients into Sparse Tensor before running optimizer if cfg.sparse_grad: @@ -842,7 +851,7 @@ def eval(self, step: int, stage: str = "val"): Ks=Ks, width=width, height=height, - sh_degree=cfg.sh_degree, + sh_degree=None, near_plane=cfg.near_plane, far_plane=cfg.far_plane, ) # [1, H, W, 3] @@ -946,7 +955,7 @@ def render_traj(self, step: int): Ks=Ks, width=width, height=height, - sh_degree=cfg.sh_degree, + sh_degree=None, near_plane=cfg.near_plane, far_plane=cfg.far_plane, render_mode="RGB+ED", diff --git a/gsplat/rendering.py b/gsplat/rendering.py index 174f8121d..f93f62c52 
100644 --- a/gsplat/rendering.py +++ b/gsplat/rendering.py @@ -582,7 +582,7 @@ def reshape_view(C: int, world_view: torch.Tensor, N_world: list) -> torch.Tenso return render_colors, render_alphas, meta -def filter_visible_gaussians( +def view_to_visible_anchors( means: Tensor, # [N, 3] quats: Tensor, # [N, 4] scales: Tensor, # [N, 3] diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 0789dfcbc..9b0fb81d3 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -4,6 +4,7 @@ import torch import torch.nn.functional as F from torch import Tensor +from torch_scatter import scatter_max from gsplat import quat_scale_to_covar_preci from gsplat.relocation import compute_relocation @@ -361,3 +362,93 @@ def op_sigmoid(x, k=100, x0=0.995): ) noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) + + +@torch.no_grad() +def grow_anchors( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + anchors: torch.Tensor, + gradient_mask: torch.Tensor, + remove_duplicates_mask: torch.Tensor, + inv_idx: torch.Tensor, + voxel_size: float, + n_feat_offsets: int, + feat_dim: int, +): + """Inplace add new Gaussians (anchors) to the parameters. + + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + state: A dictionary of extra state tensors. + anchors: Positions of new anchors to be added. + gradient_mask: A mask to select gradients. + remove_duplicates_mask: A mask to remove duplicates. + inv_idx: Indices for inverse mapping. + voxel_size: The size of the voxel. + n_feat_offsets: Number of feature offsets. + feat_dim: Dimension of features. + """ + device = anchors.device + num_new = anchors.size(0) + + # Scale anchors + anchors = anchors * voxel_size # [N_new, 3] + + # Initialize new parameters + log_voxel_size = torch.log(torch.tensor(voxel_size, device=device)) + scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2) # [N_new, 6] + + rotation = torch.zeros((num_new, 4), device=device) + rotation[:, 0] = 1.0 # Identity quaternion + + # Prepare new features + existing_features = params["features"] # [N_existing, feat_dim] + repeated_features = existing_features.repeat_interleave(n_feat_offsets, dim=0) # [N_existing * n_feat_offsets, feat_dim] + + selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] + + # Use inverse_indices to aggregate features + scattered_features, _ = scatter_max( + selected_features, + inv_idx.unsqueeze(1).expand(-1, feat_dim), + dim=0 + ) + feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim] + + # Initialize new offsets + offsets = torch.zeros((num_new, n_feat_offsets, 3), device=device) # [N_new, n_feat_offsets, 3] + + def param_fn(name: str, p: Tensor) -> Tensor: + if name == "anchors": + p_new = torch.cat([p, anchors], dim=0) + elif name == "scales": + p_new = torch.cat([p, scaling], dim=0) + elif name == "quats": + p_new = torch.cat([p, rotation], dim=0) + elif name == "features": + p_new = torch.cat([p, feat], dim=0) + elif name == "offsets": + p_new = torch.cat([p, offsets], dim=0) + else: + raise ValueError(f"Parameter '{name}' not recognized.") + return torch.nn.Parameter(p_new) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + # Extend optimizer state tensors with zeros + zeros = torch.zeros((num_new, *v.shape[1:]), device=device) + v_new = torch.cat([v, zeros], dim=0) + return v_new + + # Update parameters and 
optimizer states + names = ["anchors", "scales", "quats", "features", "offsets"] + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + + # Update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor): + zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) + state[k] = torch.cat([v, zeros], dim=0) + diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index b62c777ff..bd72d6004 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -4,7 +4,7 @@ import torch from .base import Strategy -from .ops import duplicate, remove +from .ops import duplicate, remove, grow_anchors @dataclass @@ -79,30 +79,32 @@ class ScaffoldStrategy(Strategy): prune_scale2d: float = 0.15 refine_scale2d_stop_iter: int = 0 refine_start_iter: int = 500 + max_voxel_levels: int = 3 + voxel_size: float = 0.001 refine_stop_iter: int = 15_000 - mean_feat_dim: int = 32 + feat_dim: int = 32 n_feat_offsets: int = 10 refine_every: int = 100 pause_refine_after_reset: int = 0 absgrad: bool = False revised_opacity: bool = False - verbose: bool = False + verbose: bool = True colors_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), - torch.nn.Linear(mean_feat_dim, 3 * n_feat_offsets), + torch.nn.Linear(feat_dim, 3 * n_feat_offsets), torch.nn.Sigmoid() ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), - torch.nn.Linear(mean_feat_dim, n_feat_offsets), + torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh() ).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(mean_feat_dim + 3, mean_feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), - torch.nn.Linear(mean_feat_dim, 7 * n_feat_offsets) + torch.nn.Linear(feat_dim, 7 * n_feat_offsets) ).cuda() def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: @@ -116,6 +118,7 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. # - radii: the radii of the GSs (normalized by the image resolution). + # - radii: the radii of the GSs (normalized by the image resolution). state = {"grad2d": None, "count": None, "scene_scale": scene_scale} @@ -193,20 +196,20 @@ def step_post_backward( and step % self.refine_every == 0 ): # grow GSs - n_dupli, n_split = self._grow_anchors(params, optimizers, state, step) + new_anchors = self._anchor_growing(params, optimizers, state, step) if self.verbose: print( - f"Step {step}: {n_dupli} GSs duplicated, {n_split} GSs split. " + f"Step {step}: {new_anchors} anchors grown." f"Now having {len(params['anchors'])} GSs." ) # prune GSs - n_prune = self._prune_gs(params, optimizers, state, step) - if self.verbose: - print( - f"Step {step}: {n_prune} GSs pruned. " - f"Now having {len(params['anchors'])} GSs." - ) + # n_prune = self._prune_gs(params, optimizers, state, step) + # if self.verbose: + # print( + # f"Step {step}: {n_prune} GSs pruned. " + # f"Now having {len(params['anchors'])} GSs." 
+ # ) # reset running stats state["grad2d"].zero_() @@ -223,26 +226,27 @@ def _update_state( info: Dict[str, Any], packed: bool = False, ): - for key in ["anchors2d", "width", "height", "n_cameras", "radii", "gaussian_ids"]: + for key in ["width", "height", "n_cameras", "radii", "gaussian_ids"]: assert key in info, f"{key} is required but missing." # normalize grads to [-1, 1] screen space if self.absgrad: - grads = info["anchors2d"].absgrad.clone() + grads = info["means2d"].absgrad.clone() else: - grads = info["anchors2d"].grad.clone() + grads = info["means2d"].grad.clone() grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"] grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"] # initialize state on the first run - n_gaussian = len(list(params["anchors"].shape[0])) + n_gaussian = params["anchors"].shape[0] if state["grad2d"] is None: - state["grad2d"] = torch.zeros(n_gaussian, device=grads.device) + state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) if state["count"] is None: - state["count"] = torch.zeros(n_gaussian, device=grads.device) + state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) if self.refine_scale2d_stop_iter > 0 and state["radii"] is None: assert "radii" in info, "radii is required but missing." - state["radii"] = torch.zeros(n_gaussian, device=grads.device) + state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) + # update the running state if packed: @@ -255,11 +259,22 @@ def _update_state( gs_ids = torch.where(sel)[1] # [nnz] grads = grads[sel] # [nnz, 2] radii = info["radii"][sel] # [nnz] + # update neural gaussian statis + visible_anchor_mask = info["visible_anchor_mask"] + neural_selection_mask = info["neural_selection_mask"] + # Extend to + anchor_visible_mask = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.n_feat_offsets]).view(-1) + neural_gaussian_mask = torch.zeros_like(state["grad2d"], dtype=torch.bool) + neural_gaussian_mask[anchor_visible_mask] = neural_selection_mask + valid_mask = neural_gaussian_mask[gs_ids] + + # Filter gs_ids and grads based on the valid_mask + valid_gs_ids = gs_ids[valid_mask] + valid_grads_norm = grads.norm(dim=-1)[valid_mask] + + state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm) + state["count"].index_add_(0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)) - state["grad2d"].index_add_(0, gs_ids, grads.norm(dim=-1)) - state["count"].index_add_( - 0, gs_ids, torch.ones_like(gs_ids, dtype=torch.float32) - ) if self.refine_scale2d_stop_iter > 0: # Should be ideally using scatter max state["radii"][gs_ids] = torch.maximum( @@ -269,14 +284,110 @@ def _update_state( ) @torch.no_grad() - def _grow_anchors( + def _anchor_growing( self, params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Any], step: int, - ) -> Tuple[int, int]: - pass + ) -> Tuple[int]: + """ + Implements the Anchor Growing algorithm as described in Algorithm 1 of the + GS-Scaffold appendix: + https://openaccess.thecvf.com/content/CVPR2024/supplemental/Lu_Scaffold-GS_Structured_3D_CVPR_2024_supplemental.pdf + + This method performs anchor growing for structured optimization during training, + which helps improve the generalization and stability of the model. + + Args: + params (Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict]): + The model's parameters to be optimized. 
+ optimizers (Dict[str, torch.optim.Optimizer]): + A dictionary of optimizers associated with the model's parameters. + state (Dict[str, Any]): + A dictionary containing the current state of the training process. + step (int): + The current step or iteration of the training process. + + Returns: + Tuple[int]: + A tuple containing the updated step value after anchor growing. + """ + + count = state["count"] + grads = state["grad2d"] / count.clamp_min(1) + grads[grads.isnan()] = 0.0 + device = grads.device + + # is_grad_high = (grads > self.grow_grad2d).squeeze(-1) + # is_small = (n + # torch.exp(params["scales"]).max(dim=-1).values + # <= self.grow_scale3d * state["scene_scale"] + # ) + # is_dupli = is_grad_high & is_small + # n_dupli = is_dupli.sum().item() + + n_init_features = params["features"].shape[0] * self.n_feat_offsets + n_added_anchors = 0 + + # Algorithm 1: Anchor Growing + # Step 1: Initialization + m = 1 # Iteration count (levels) + M = self.max_voxel_levels + tau_g = self.grow_grad2d + epsilon_g = self.voxel_size + new_anchors: torch.tensor + + # Step 2: Iterate while m <= M + while m <= M: + n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features + # Check if anchor candidates have grown + if n_feature_diff == 0 and (m-1) > 0: + break + + # Step 3: Update threshold and voxel size + tau = tau_g * 2 ** (m - 1) + current_voxel_size = (16 // (4 ** (m - 1))) * epsilon_g + + # Step 4: Mask from grad threshold. Select neural gaussians (Select candidates) + gradient_mask = grads >= tau + gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0) + neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:,:3].unsqueeze(dim=1) + selected_neural_gaussians = neural_gaussians.view([-1,3])[gradient_mask] + + # Step 5: Merge same positions + selected_grid_coords = torch.round(selected_neural_gaussians / current_voxel_size).int() + selected_grid_coords_unique, inv_idx = torch.unique(selected_grid_coords, return_inverse=True, dim=0) + + # Step 6: Random elimination + # TODO: implement rand elimination (necessary)? 
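+            # A minimal sketch of what this could look like, following the
+            # reference Scaffold-GS implementation, which randomly thins the
+            # candidate mask so that higher levels keep exponentially fewer
+            # candidates (illustrative only, not enabled in this patch;
+            # `m` is the 1-based level counter used here):
+            #     rand_mask = torch.rand_like(gradient_mask.float()) > (0.5 ** m)
+            #     gradient_mask = torch.logical_and(gradient_mask, rand_mask)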
+ + # Get the grid coordinates of the current anchors + grid_anchor_coords = torch.round(params["anchors"] / current_voxel_size).int() + + # Step 7: Remove occupied by comparing the unique coordinates to current anchors + remove_occupied_pos_mask = ~( + (selected_grid_coords_unique.unsqueeze(1) == grid_anchor_coords).all(-1).any(-1).view(-1)) + + # New anchor candidates are those unique coordinates that are not duplicates + new_anchors = selected_grid_coords_unique[remove_occupied_pos_mask] + + if new_anchors.shape[0] > 0: + grow_anchors(params=params, + optimizers=optimizers, + state=state, + anchors=new_anchors, + gradient_mask=gradient_mask, + remove_duplicates_mask=remove_occupied_pos_mask, + inv_idx=inv_idx, + voxel_size=current_voxel_size, + n_feat_offsets=self.n_feat_offsets, + feat_dim=self.feat_dim) + + n_added_anchors += new_anchors.shape[0] + m += 1 + + return n_added_anchors @torch.no_grad() def _prune_gs( From 1c101f781566d3f266602ed0424e9229ee6c0b39 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Thu, 19 Sep 2024 11:40:14 +0200 Subject: [PATCH 03/29] implement the anchor pruning mechanism --- gsplat/strategy/ops.py | 11 +++- gsplat/strategy/scaffold.py | 128 +++++++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 43 deletions(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 9b0fb81d3..bedcb5a1b 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -181,13 +181,16 @@ def remove( optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Tensor], mask: Tensor, + names: Union[List[str], None] = None, ): """Inplace remove the Gaussian with the given mask. Args: params: A dictionary of parameters. optimizers: A dictionary of optimizers, each corresponding to a parameter. + state: A dictionary of extra state tensors. mask: A boolean mask to remove the Gaussians. + names: A list of key names to update. If None, update all. Default: None. """ sel = torch.where(~mask)[0] @@ -198,7 +201,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v[sel] # update the parameters and the state in the optimizers - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -360,7 +363,6 @@ def op_sigmoid(x, k=100, x0=0.995): * (op_sigmoid(1 - opacities)).unsqueeze(-1) * scaler ) - noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) @@ -449,6 +451,9 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: # Update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): - zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) + if k == "anchor_count" or k == "anchor_opacity": + zeros = torch.zeros((num_new, *v.shape[1:]), device=device) + else: + zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) state[k] = torch.cat([v, zeros], dim=0) diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index bd72d6004..f14655fad 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -118,9 +118,10 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. # - radii: the radii of the GSs (normalized by the image resolution). 
- # - radii: the radii of the GSs (normalized by the image resolution). state = {"grad2d": None, "count": None, + "anchor_count": None, + "anchor_opacity": None, "scene_scale": scene_scale} if self.refine_scale2d_stop_iter > 0: state["radii"] = None @@ -204,51 +205,75 @@ def step_post_backward( ) # prune GSs - # n_prune = self._prune_gs(params, optimizers, state, step) - # if self.verbose: - # print( - # f"Step {step}: {n_prune} GSs pruned. " - # f"Now having {len(params['anchors'])} GSs." - # ) + is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() + mask = state["anchor_opacity"] > self.prune_opa + print(mask.sum().item()) + + n_prune = is_prune.sum().item() + if n_prune > 0: + names = ["anchors", "scales", "quats", "features", "offsets"] + remove(params=params, optimizers=optimizers, state=state, mask=is_prune, names=names) + + if self.verbose: + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(params['anchors'])} GSs." + ) # reset running stats state["grad2d"].zero_() state["count"].zero_() + state["anchor_count"].zero_() + state["anchor_opacity"].zero_() + device = params["anchors"].device + n_gaussian = params["anchors"].shape[0] + state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + state["anchor_count"] = torch.zeros(n_gaussian, device=device) + state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) if self.refine_scale2d_stop_iter > 0: state["radii"].zero_() torch.cuda.empty_cache() def _update_state( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - state: Dict[str, Any], - info: Dict[str, Any], - packed: bool = False, + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + state: Dict[str, Any], + info: Dict[str, Any], + packed: bool = False, ): - for key in ["width", "height", "n_cameras", "radii", "gaussian_ids"]: + # Ensure required keys are present + required_keys = ["width", "height", "n_cameras", "radii", "gaussian_ids"] + for key in required_keys: assert key in info, f"{key} is required but missing." 
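+        # Note on units: gradients on means2d live in a normalized [-1, 1]
+        # screen space, so they are rescaled by half the image size (times the
+        # number of cameras) to obtain pixel units; e.g. an x-gradient of 0.01
+        # on a 1600-pixel-wide image with one camera becomes
+        # 0.01 * (1600 / 2) = 8 pixels.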
- # normalize grads to [-1, 1] screen space + # Normalize gradients to [-1, 1] screen space + scale_factors = torch.tensor( + [info["width"] / 2.0 * info["n_cameras"], info["height"] / 2.0 * info["n_cameras"]], + device=info["means2d"].device, + ) + if self.absgrad: - grads = info["means2d"].absgrad.clone() + grads = info["means2d"].absgrad.detach() * scale_factors else: - grads = info["means2d"].grad.clone() - grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"] - grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"] + grads = info["means2d"].grad.detach() * scale_factors - # initialize state on the first run + # Initialize state on the first run n_gaussian = params["anchors"].shape[0] + device = grads.device if state["grad2d"] is None: - state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) + state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) if state["count"] is None: - state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) + state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + if state["anchor_count"] is None: + state["anchor_count"] = torch.zeros(n_gaussian, device=device) + if state["anchor_opacity"] is None: + state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) if self.refine_scale2d_stop_iter > 0 and state["radii"] is None: - assert "radii" in info, "radii is required but missing." - state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) - + state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) - # update the running state + # Update the running state if packed: # grads is [nnz, 2] gs_ids = info["gaussian_ids"] # [nnz] @@ -256,31 +281,54 @@ def _update_state( else: # grads is [C, N, 2] sel = info["radii"] > 0.0 # [C, N] - gs_ids = torch.where(sel)[1] # [nnz] + gs_ids = sel.nonzero(as_tuple=False)[:, 1] # [nnz] grads = grads[sel] # [nnz, 2] radii = info["radii"][sel] # [nnz] - # update neural gaussian statis + + # Compute valid_mask efficiently visible_anchor_mask = info["visible_anchor_mask"] neural_selection_mask = info["neural_selection_mask"] - # Extend to - anchor_visible_mask = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.n_feat_offsets]).view(-1) - neural_gaussian_mask = torch.zeros_like(state["grad2d"], dtype=torch.bool) - neural_gaussian_mask[anchor_visible_mask] = neural_selection_mask - valid_mask = neural_gaussian_mask[gs_ids] - # Filter gs_ids and grads based on the valid_mask + # Compute anchor indices + anchor_indices = gs_ids // self.n_feat_offsets + + # Determine valid gs_ids based on visibility and selection masks + valid_mask = ( + visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids] + ) + + # Filter gs_ids and grads based on valid_mask valid_gs_ids = gs_ids[valid_mask] - valid_grads_norm = grads.norm(dim=-1)[valid_mask] + valid_grads_norm = grads[valid_mask].norm(dim=-1) + # Update state using index_add_ state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm) - state["count"].index_add_(0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)) - + state["count"].index_add_( + 0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32) + ) + + # Update anchor opacity and count + anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1) + neural_opacities = ( + info["neural_opacities"] + .detach() + .view(-1, self.n_feat_offsets) + .clamp_min_(0) + .sum(dim=1) + ) + + state["anchor_opacity"].index_add_(0, anchor_ids, 
neural_opacities) + state["anchor_count"].index_add_( + 0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32) + ) + + # Update radii if required if self.refine_scale2d_stop_iter > 0: - # Should be ideally using scatter max + # Normalize radii to [0, 1] screen space + normalized_radii = radii / float(max(info["width"], info["height"])) + # Update radii using torch.maximum state["radii"][gs_ids] = torch.maximum( - state["radii"][gs_ids], - # normalize radii to [0, 1] screen space - radii / float(max(info["width"], info["height"])), + state["radii"][gs_ids], normalized_radii ) @torch.no_grad() From 09ebdfa3360a1151f04c1d4cba61a278360a9ba3 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 20 Sep 2024 09:55:40 +0200 Subject: [PATCH 04/29] better params --- examples/simple_trainer_scaffold.py | 36 +++++++++++++---------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 7afc740ff..9debd839e 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -205,12 +205,12 @@ def create_splats_with_optimizers( params = [ # name, value, lr - ("anchors", torch.nn.Parameter(points), 1.6e-4 * scene_scale), - ("scales", torch.nn.Parameter(scales), 5e-3), - ("quats", torch.nn.Parameter(quats), 1e-3), - ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.02), - ("features", torch.nn.Parameter(features), 1.6e-4 * scene_scale), - ("offsets", torch.nn.Parameter(offsets), 0.004), + ("anchors", torch.nn.Parameter(points), 0), + ("scales", torch.nn.Parameter(scales), 0.007), + ("quats", torch.nn.Parameter(quats), 0.002), + ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002), + ("features", torch.nn.Parameter(features), 0.0075 * scene_scale), + ("offsets", torch.nn.Parameter(offsets), 0.01), ("colors_mlp", strategy.colors_mlp.parameters(), 0.008), ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004), ] @@ -554,9 +554,14 @@ def train(self): init_step = 0 schedulers = [ - # anchors has a learning rate schedule, that end at 0.01 of the initial value torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["anchors"], gamma=0.01 ** (1.0 / max_steps) + self.optimizers["offsets"], gamma=0.01 ** (1.0 / max_steps) + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["opacities_mlp"], gamma=0.002 ** (1.0 / max_steps) + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["colors_mlp"], gamma=0.008 ** (1.0 / max_steps) ), ] if cfg.pose_opt: @@ -697,15 +702,6 @@ def train(self): tvloss = 10 * total_variation_loss(self.bil_grids.grids) loss += tvloss - # regularizations - # not gonna work. 
Check this - # if cfg.opacity_reg > 0.0: - # loss = ( - # loss - # + cfg.opacity_reg - # * torch.abs(torch.sigmoid(self.splats["opacities"])).mean() - # ) - loss.backward() desc = f"loss={loss.item():.3f}| " @@ -805,9 +801,9 @@ def train(self): scheduler.step() # eval the full set - if step in [i - 1 for i in cfg.eval_steps]: - self.eval(step) - self.render_traj(step) + # if step in [i - 1 for i in cfg.eval_steps]: + # self.eval(step) + # self.render_traj(step) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: From 3423efb13d9896bedf91cd7c6f2309e3af957614 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 20 Sep 2024 10:36:09 +0200 Subject: [PATCH 05/29] add observation threshold --- gsplat/strategy/scaffold.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index f14655fad..740ab51b7 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -85,22 +85,31 @@ class ScaffoldStrategy(Strategy): feat_dim: int = 32 n_feat_offsets: int = 10 refine_every: int = 100 + + # 3.3 Observation Thresholds (compare paper) + # "To enhance the robustness of the Growing and Pruning operations for long image sequences, ..." + pruning_thresholds: float = 0.8 + growing_thresholds: float = 0.4 + pause_refine_after_reset: int = 0 absgrad: bool = False revised_opacity: bool = False verbose: bool = True + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 3 * n_feat_offsets), torch.nn.Sigmoid() ).cuda() + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh() ).cuda() + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), @@ -201,13 +210,12 @@ def step_post_backward( if self.verbose: print( f"Step {step}: {new_anchors} anchors grown." - f"Now having {len(params['anchors'])} GSs." + f"Now having {len(params['anchors'])} anchors." ) # prune GSs is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() - mask = state["anchor_opacity"] > self.prune_opa - print(mask.sum().item()) + is_prune = torch.logical_and(is_prune, state["anchor_count"] > self.refine_every * self.pruning_thresholds) n_prune = is_prune.sum().item() if n_prune > 0: @@ -216,8 +224,8 @@ def step_post_backward( if self.verbose: print( - f"Step {step}: {n_prune} GSs pruned. " - f"Now having {len(params['anchors'])} GSs." + f"Step {step}: {n_prune} anchors pruned. " + f"Now having {len(params['anchors'])} anchors." ) # reset running stats @@ -386,6 +394,7 @@ def _anchor_growing( epsilon_g = self.voxel_size new_anchors: torch.tensor + growing_treshold = state["count"] > self.refine_every * self.growing_thresholds # Step 2: Iterate while m <= M while m <= M: n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features @@ -399,6 +408,8 @@ def _anchor_growing( # Step 4: Mask from grad threshold. 
Select neural gaussians (Select candidates) gradient_mask = grads >= tau + # + gradient_mask = torch.logical_and(gradient_mask, growing_treshold) gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0) neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:,:3].unsqueeze(dim=1) selected_neural_gaussians = neural_gaussians.view([-1,3])[gradient_mask] From 8c5b0c5bb2ed5f71d403aad657fb14cc2f999e20 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 20 Sep 2024 11:38:14 +0200 Subject: [PATCH 06/29] Should be properly working up to param optimization --- examples/simple_trainer_scaffold.py | 193 +++++++------- gsplat/compression/png_compression.py | 4 +- gsplat/strategy/ops.py | 77 ++++-- gsplat/strategy/scaffold.py | 357 +++++++++++++------------- 4 files changed, 350 insertions(+), 281 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 9debd839e..bc4a41dfd 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -54,7 +54,8 @@ class Config: render_traj_path: str = "interp" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/paja/data/bike" + # data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden" + data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset data_factor: int = 1 # Directory to save results @@ -88,8 +89,6 @@ class Config: # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm init_extent: float = 3.0 # Turn on another SH degree every this steps - sh_degree_interval: int = 1000 - # Initial opacity of GS init_opa: float = 0.1 # Initial scale of GS init_scale: float = 1.0 @@ -113,8 +112,6 @@ class Config: # Use random background for training to discourage transparency random_bkgd: bool = False - # Opacity regularization - opacity_reg: float = 0.0 # Scale regularization scale_reg: float = 0.01 @@ -157,7 +154,6 @@ def adjust_steps(self, factor: float): self.eval_steps = [int(i * factor) for i in self.eval_steps] self.save_steps = [int(i * factor) for i in self.save_steps] self.max_steps = int(self.max_steps * factor) - self.sh_degree_interval = int(self.sh_degree_interval * factor) strategy = self.strategy strategy.refine_start_iter = int(strategy.refine_start_iter * factor) @@ -185,7 +181,7 @@ def create_splats_with_optimizers( ) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: # Compare GS-Scaffold paper formula (4) - points = np.unique(np.round(parser.points/voxel_size), axis=0)*voxel_size + points = np.unique(np.round(parser.points / voxel_size), axis=0) * voxel_size points = torch.from_numpy(points).float() # Initialize the GS size to be the average dist of the 3 nearest neighbors @@ -196,7 +192,6 @@ def create_splats_with_optimizers( # Distribute the GSs to different ranks (also works for single rank) points = points[world_rank::world_size] scales = scales[world_rank::world_size] - N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] @@ -209,8 +204,8 @@ def create_splats_with_optimizers( ("scales", torch.nn.Parameter(scales), 0.007), ("quats", torch.nn.Parameter(quats), 0.002), ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002), - ("features", torch.nn.Parameter(features), 0.0075 * scene_scale), - ("offsets", torch.nn.Parameter(offsets), 0.01), + ("features", torch.nn.Parameter(features), 0.0075), + ("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale), 
("colors_mlp", strategy.colors_mlp.parameters(), 0.008), ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004), ] @@ -242,6 +237,7 @@ def __init__( set_random_seed(42 + local_rank) self.cfg = cfg + self.cfg.strategy.voxel_size = self.cfg.voxel_size self.world_rank = world_rank self.local_rank = local_rank self.world_size = world_size @@ -395,24 +391,21 @@ def __init__( mode="training", ) - def get_visible_anchor_mask( - self, - camtoworlds: Tensor, - Ks: Tensor, - width: int, - height: int, - packed: bool, - rasterize_mode: str, -): - anchors = self.splats["anchors"] # [N, 3] - # rasterization does normalization internally - quats = self.splats["quats"] # [N, 4] - scales = torch.exp(self.splats["scales"])[:, :3] # [N, 3] + def get_neural_gaussians( + self, + camtoworlds: Tensor, + Ks: Tensor, + width: int, + height: int, + packed: bool, + rasterize_mode: str, + ): + # Compare paper: Helps mainly to speed up the rasterization. Has no quality impact visible_anchor_mask = view_to_visible_anchors( - means=anchors, - quats=quats, - scales=scales, + means=self.splats["anchors"], + quats=self.splats["quats"], + scales=torch.exp(self.splats["scales"])[:, :3], viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] Ks=Ks, # [C, 3, 3] width=width, @@ -420,26 +413,27 @@ def get_visible_anchor_mask( packed=packed, rasterize_mode=rasterize_mode, ) - - return visible_anchor_mask - - def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None): - # If no visibility mask is provided, we select all anchors including their offsets - if visible_anchor_mask is None: - visible_anchor_mask = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device) - selected_features = self.splats["features"][visible_anchor_mask] # [M, c] selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3] selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3] - selected_scales = torch.exp(self.splats["scales"][visible_anchor_mask]) # [M, 6] + selected_scales = torch.exp( + self.splats["scales"][visible_anchor_mask] + ) # [M, 6] # See formula (5) in Scaffold-GS + + cam_pos = camtoworlds[:, :3, 3] view_dir = selected_anchors - cam_pos # [M, 3] - view_dir_normalized = view_dir / view_dir.norm(dim=1, keepdim=True) # [M, 3] + length = view_dir.norm(dim=1, keepdim=True) + view_dir_normalized = view_dir / length # [M, 3] + + view_length = torch.cat([view_dir_normalized, length], dim=1) # See formula (9) and the appendix for the rest - feature_view_dir = torch.cat([selected_features, view_dir_normalized], dim=1) # [M, c+3] + feature_view_dir = torch.cat( + [selected_features, view_length], dim=1 + ) # [M, c+3] k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor @@ -458,26 +452,50 @@ def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None): # Reshape selected_offsets, scales, and anchors selected_offsets = selected_offsets.view(-1, 3) # [M*k, 3] - scales_repeated = selected_scales.unsqueeze(1).repeat(1, k, 1).view(-1, 6) # [M*k, 6] - anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) # [M*k, 3] + scales_repeated = ( + selected_scales.unsqueeze(1).repeat(1, k, 1).view(-1, 6) + ) # [M*k, 6] + anchors_repeated = ( + selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) + ) # [M*k, 3] # Apply positive opacity mask - selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [m] - selected_colors = neural_colors[neural_selection_mask] # [m, 3] - selected_scale_rot = neural_scale_rot[neural_selection_mask] # [m, 7] - 
selected_offsets = selected_offsets[neural_selection_mask] # [m, 3] - scales_repeated = scales_repeated[neural_selection_mask] # [m, 6] - anchors_repeated = anchors_repeated[neural_selection_mask] # [m, 3] + selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [M] + selected_colors = neural_colors[neural_selection_mask] # [M, 3] + selected_scale_rot = neural_scale_rot[neural_selection_mask] # [M, 7] + selected_offsets = selected_offsets[neural_selection_mask] # [M, 3] + scales_repeated = scales_repeated[neural_selection_mask] # [M, 6] + anchors_repeated = anchors_repeated[neural_selection_mask] # [M, 3] # Compute scales and rotations - scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3]) # [m, 3] - rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [m, 4] + scales = scales_repeated[:, 3:] * torch.sigmoid( + selected_scale_rot[:, :3] + ) # [M, 3] + rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [M, 4] # Compute offsets and anchors - offsets = selected_offsets * scales_repeated[:, :3] # [m, 3] - means = anchors_repeated + offsets # [m, 3] + offsets = selected_offsets * scales_repeated[:, :3] # [M, 3] + means = anchors_repeated + offsets # [M, 3] - return means, selected_colors, selected_opacity, scales, rotation, neural_opacity, neural_selection_mask + v_a = ( + visible_anchor_mask.unsqueeze(dim=1) + .repeat([1, self.cfg.strategy.n_feat_offsets]) + .view(-1) + ) + all_neural_gaussians = torch.zeros_like(v_a, dtype=torch.bool) + all_neural_gaussians[v_a] = neural_selection_mask + + info = { + "means": means, + "colors": selected_colors, + "opacities": selected_opacity, + "scales": scales, + "quats": rotation, + "neural_opacities": neural_opacity, + "neural_selection_mask": all_neural_gaussians, + "visible_anchor_mask": visible_anchor_mask, + } + return info def rasterize_splats( self, @@ -486,38 +504,36 @@ def rasterize_splats( width: int, height: int, **kwargs, - ) -> Tuple[Tensor, Tensor, Dict, Tensor]: - - # We select only the visible anchors for faster inference - visible_anchor_mask = self.get_visible_anchor_mask(camtoworlds=camtoworlds, - Ks=Ks, - width=width, - height=height, - packed=self.cfg.packed, - rasterize_mode = "antialiased" if self.cfg.antialiased else "classic", - ) + ) -> Tuple[Tensor, Tensor, Dict]: # Get all the gaussians per voxel spawned from the anchors - means, color_mlp, opacities, scales, quats, neural_opacity, neural_selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], visible_anchor_mask=visible_anchor_mask) + info = self.get_neural_gaussians( + camtoworlds=camtoworlds, + Ks=Ks, + width=width, + height=height, + packed=self.cfg.packed, + rasterize_mode="antialiased" if self.cfg.antialiased else "classic", + ) image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( features=self.splats["features"], embed_ids=image_ids, - dirs=means[None, :, :] - camtoworlds[:, None, :3, 3], + dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3], ) - colors = colors + color_mlp + colors = colors + info["colors"] colors = torch.sigmoid(colors) else: - colors = color_mlp # [N, K, 3] + colors = info["colors"] # [N, K, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" - render_colors, render_alphas, info = rasterization( - means=means, - quats=quats, - scales=scales, - opacities=opacities, + render_colors, render_alphas, raster_info = rasterization( + means=info["means"], + quats=info["quats"], + scales=info["scales"], + 
opacities=info["opacities"], colors=colors, viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] Ks=Ks, # [C, 3, 3] @@ -530,14 +546,8 @@ def rasterize_splats( distributed=self.world_size > 1, **kwargs, ) - info.update( - { - "visible_anchor_mask": visible_anchor_mask, - "neural_selection_mask": neural_selection_mask, - "neural_opacities": neural_opacity, - } - ) - return render_colors, render_alphas, info, scales + raster_info.update(info) + return render_colors, render_alphas, raster_info def train(self): cfg = self.cfg @@ -555,13 +565,17 @@ def train(self): schedulers = [ torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["offsets"], gamma=0.01 ** (1.0 / max_steps) + self.optimizers["anchors"], gamma=0.001 ** (1.0 / max_steps) + ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["offsets"], + gamma=(0.01 * self.scene_scale) ** (1.0 / max_steps), ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["opacities_mlp"], gamma=0.002 ** (1.0 / max_steps) + self.optimizers["opacities_mlp"], gamma=0.001 ** (1.0 / max_steps) ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["colors_mlp"], gamma=0.008 ** (1.0 / max_steps) + self.optimizers["colors_mlp"], gamma=0.00625 ** (1.0 / max_steps) ), ] if cfg.pose_opt: @@ -601,6 +615,11 @@ def train(self): # Training loop. global_tic = time.time() pbar = tqdm.tqdm(range(init_step, max_steps)) + + self.cfg.strategy.scale_rot_mlp.train() + self.cfg.strategy.opacities_mlp.train() + self.cfg.strategy.colors_mlp.train() + for step in pbar: if not cfg.disable_viewer: while self.viewer.state.status == "paused": @@ -634,7 +653,7 @@ def train(self): camtoworlds = self.pose_adjust(camtoworlds, image_ids) # forward - renders, alphas, info, scales = self.rasterize_splats( + renders, alphas, info = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -678,7 +697,7 @@ def train(self): ) loss = l1loss * (1.0 - cfg.ssim_lambda) loss += ssimloss * cfg.ssim_lambda - loss += scales.prod(dim=1).mean() * cfg.scale_reg + loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg if cfg.depth_loss: # query depths from depth map points = torch.stack( @@ -718,7 +737,9 @@ def train(self): self.writer.add_scalar("train/loss", loss.item(), step) self.writer.add_scalar("train/l1loss", l1loss.item(), step) self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) - self.writer.add_scalar("train/num_GS", len(self.splats["anchors"]), step) + self.writer.add_scalar( + "train/num_GS", len(self.splats["anchors"]), step + ) self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: self.writer.add_scalar("train/depthloss", depthloss.item(), step) @@ -842,7 +863,7 @@ def eval(self, step: int, stage: str = "val"): torch.cuda.synchronize() tic = time.time() - colors, _, _, _ = self.rasterize_splats( + colors, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -946,7 +967,7 @@ def render_traj(self, step: int): camtoworlds = camtoworlds_all[i : i + 1] Ks = K[None] - renders, _, _, _ = self.rasterize_splats( + renders, _, _ = self.rasterize_splats( camtoworlds=camtoworlds, Ks=Ks, width=width, @@ -1003,7 +1024,7 @@ def _viewer_render_fn( c2w = torch.from_numpy(c2w).float().to(self.device) K = torch.from_numpy(K).float().to(self.device) - render_colors, _, _, _ = self.rasterize_splats( + render_colors, _, _ = self.rasterize_splats( camtoworlds=c2w[None], Ks=K[None], width=W, diff --git a/gsplat/compression/png_compression.py b/gsplat/compression/png_compression.py index ee4ead6f9..2dbece8cf 100644 --- 
a/gsplat/compression/png_compression.py +++ b/gsplat/compression/png_compression.py @@ -368,9 +368,7 @@ def _compress_kmeans( maxs = torch.max(centroids) centroids_norm = (centroids - mins) / (maxs - mins) centroids_norm = centroids_norm.detach().cpu().numpy() - centroids_quant = ( - (centroids_norm * (2**quantization - 1)).round().astype(np.uint8) - ) + centroids_quant = (centroids_norm * (2**quantization - 1)).round().astype(np.uint8) labels = labels.astype(np.uint16) npz_dict = { diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index bedcb5a1b..178ab2423 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -368,16 +368,16 @@ def op_sigmoid(x, k=100, x0=0.995): @torch.no_grad() def grow_anchors( - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - state: Dict[str, Tensor], - anchors: torch.Tensor, - gradient_mask: torch.Tensor, - remove_duplicates_mask: torch.Tensor, - inv_idx: torch.Tensor, - voxel_size: float, - n_feat_offsets: int, - feat_dim: int, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + state: Dict[str, Tensor], + anchors: torch.Tensor, + gradient_mask: torch.Tensor, + remove_duplicates_mask: torch.Tensor, + inv_idx: torch.Tensor, + voxel_size: float, + n_feat_offsets: int, + feat_dim: int, ): """Inplace add new Gaussians (anchors) to the parameters. @@ -403,25 +403,26 @@ def grow_anchors( log_voxel_size = torch.log(torch.tensor(voxel_size, device=device)) scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2) # [N_new, 6] - rotation = torch.zeros((num_new, 4), device=device) - rotation[:, 0] = 1.0 # Identity quaternion + rotation = torch.ones((num_new, 4), device=device) # Prepare new features existing_features = params["features"] # [N_existing, feat_dim] - repeated_features = existing_features.repeat_interleave(n_feat_offsets, dim=0) # [N_existing * n_feat_offsets, feat_dim] + repeated_features = existing_features.repeat_interleave( + n_feat_offsets, dim=0 + ) # [N_existing * n_feat_offsets, feat_dim] selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] # Use inverse_indices to aggregate features scattered_features, _ = scatter_max( - selected_features, - inv_idx.unsqueeze(1).expand(-1, feat_dim), - dim=0 + selected_features, inv_idx.unsqueeze(1).expand(-1, feat_dim), dim=0 ) feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim] # Initialize new offsets - offsets = torch.zeros((num_new, n_feat_offsets, 3), device=device) # [N_new, n_feat_offsets, 3] + offsets = torch.zeros( + (num_new, n_feat_offsets, 3), device=device + ) # [N_new, n_feat_offsets, 3] def param_fn(name: str, p: Tensor) -> Tensor: if name == "anchors": @@ -454,6 +455,46 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: if k == "anchor_count" or k == "anchor_opacity": zeros = torch.zeros((num_new, *v.shape[1:]), device=device) else: - zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) + zeros = torch.zeros( + (num_new * n_feat_offsets, *v.shape[1:]), device=device + ) state[k] = torch.cat([v, zeros], dim=0) + +@torch.no_grad() +def remove_anchors( + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + n_feat_offsets: int, + state: Dict[str, Tensor], + mask: Tensor, + names: Union[List[str], None] = None, +): + """Inplace remove the Gaussian with the given mask. 
+ + Args: + params: A dictionary of parameters. + optimizers: A dictionary of optimizers, each corresponding to a parameter. + n_feat_offsets: Number of feature offsets. + state: A dictionary of extra state tensors. + mask: A boolean mask to remove the Gaussians. + names: A list of parameter names to update. If None, update all. Default: None. + """ + sel = torch.where(~mask)[0] + + def param_fn(name: str, p: Tensor) -> Tensor: + return torch.nn.Parameter(p[sel]) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + return v[sel] + + # update the parameters and the state in the optimizers + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + # update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor): + if k in ["anchor_count", "anchor_opacity"]: + state[k] = v[sel] + else: + offset_sel = sel.unsqueeze(dim=1).repeat([1, n_feat_offsets]).view(-1) + state[k] = v[offset_sel] diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 740ab51b7..c3745dadb 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -1,10 +1,9 @@ from dataclasses import dataclass -from typing import Any, Dict, Tuple, Union - +from typing import Any, Dict, Union import torch from .base import Strategy -from .ops import duplicate, remove, grow_anchors +from .ops import remove_anchors, grow_anchors @dataclass @@ -41,8 +40,6 @@ class ScaffoldStrategy(Strategy): value will be pruned. Default is 0.1. prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above this value will be pruned. Default is 0.15. - refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this - iteration. Default is 0. Set to a positive value to enable this feature. refine_start_iter (int): Start refining GSs after this iteration. Default is 500. refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. refine_every (int): Refine GSs every this steps. Default is 100. @@ -77,8 +74,7 @@ class ScaffoldStrategy(Strategy): grow_scale2d: float = 0.05 prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 - refine_scale2d_stop_iter: int = 0 - refine_start_iter: int = 500 + refine_start_iter: int = 1500 max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 @@ -96,24 +92,25 @@ class ScaffoldStrategy(Strategy): revised_opacity: bool = False verbose: bool = True + view_distance = 1 colors_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 3 * n_feat_offsets), - torch.nn.Sigmoid() + torch.nn.Sigmoid(), ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), - torch.nn.Tanh() + torch.nn.Tanh(), ).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), - torch.nn.Linear(feat_dim, 7 * n_feat_offsets) + torch.nn.Linear(feat_dim, 7 * n_feat_offsets), ).cuda() def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: @@ -126,14 +123,13 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # put them on the correct device. # - grad2d: running accum of the norm of the image plane gradients for each GS. 
# - count: running accum of how many times each GS is visible.
-        # - radii: the radii of the GSs (normalized by the image resolution).
-        state = {"grad2d": None,
-                 "count": None,
-                 "anchor_count": None,
-                 "anchor_opacity": None,
-                 "scene_scale": scene_scale}
-        if self.refine_scale2d_stop_iter > 0:
-            state["radii"] = None
+        state = {
+            "grad2d": None,
+            "count": None,
+            "anchor_count": None,
+            "anchor_opacity": None,
+            "scene_scale": scene_scale,
+        }
         return state

     def check_sanity(
@@ -159,16 +155,20 @@ def check_sanity(
         super().check_sanity(params, optimizers)

         # The following keys are required for this strategy.
-        expected_params = ["anchors",
-                           "features",
-                           "offsets",
-                           "scales",
-                           "quats",
-                           "opacities_mlp",
-                           "colors_mlp",
-                           "scale_rot_mlp"]
-
-        assert len(expected_params) == len(params), "expected params and actual params don't match"
+        expected_params = [
+            "anchors",
+            "features",
+            "offsets",
+            "scales",
+            "quats",
+            "opacities_mlp",
+            "colors_mlp",
+            "scale_rot_mlp",
+        ]
+
+        assert len(expected_params) == len(
+            params
+        ), "expected params and actual params don't match"
         for key in expected_params:
             assert key in params, f"{key} is required in params but missing."

@@ -199,28 +199,44 @@ def step_post_backward(
         if step >= self.refine_stop_iter:
             return

-        self._update_state(params, state, info, packed=packed)
+        if step > 500:
+            self._update_state(params, state, info, packed=packed)

-        if (
-            step > self.refine_start_iter
-            and step % self.refine_every == 0
-        ):
+        if step > self.refine_start_iter and step % self.refine_every == 0:
             # grow GSs
-            new_anchors = self._anchor_growing(params, optimizers, state, step)
+            print(f"init anchors: {len(params['anchors'])}")
+            new_anchors = self._anchor_growing(params, optimizers, state)
             if self.verbose:
                 print(
                     f"Step {step}: {new_anchors} anchors grown. "
                     f"Now having {len(params['anchors'])} anchors."
                 )

-            # prune GSs
-            is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze()
-            is_prune = torch.logical_and(is_prune, state["anchor_count"] > self.refine_every * self.pruning_thresholds)
+            # prune anchors
+            low_opacity_mask = (
+                state["anchor_opacity"] < self.prune_opa * state["anchor_count"]
+            ).squeeze()
+            anchor_mask = (
+                state["anchor_count"] > self.pruning_thresholds * self.refine_every
+            )  # [N]
+            is_prune = torch.logical_and(low_opacity_mask, anchor_mask)
+
+            indices = anchor_mask.nonzero(as_tuple=False).squeeze()
+            # Efficiently reset the running stats at the selected indices to zero
+            state["anchor_count"].index_fill_(0, indices, 0)
+            state["anchor_opacity"].index_fill_(0, indices, 0)

             n_prune = is_prune.sum().item()
             if n_prune > 0:
                 names = ["anchors", "scales", "quats", "features", "offsets"]
-                remove(params=params, optimizers=optimizers, state=state, mask=is_prune, names=names)
+                remove_anchors(
+                    params=params,
+                    optimizers=optimizers,
+                    n_feat_offsets=self.n_feat_offsets,
+                    state=state,
+                    mask=is_prune,
+                    names=names,
+                )

             if self.verbose:
                 print(
                     f"Step {step}: {n_prune} anchors pruned. "
                     f"Now having {len(params['anchors'])} anchors."
                 )

-            # reset running stats
-            state["grad2d"].zero_()
-            state["count"].zero_()
-            state["anchor_count"].zero_()
-            state["anchor_opacity"].zero_()
-            device = params["anchors"].device
-            n_gaussian = params["anchors"].shape[0]
-            state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
-            state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
-            state["anchor_count"] = torch.zeros(n_gaussian, device=device)
-            state["anchor_opacity"] = torch.zeros(n_gaussian, device=device)
-
-            if self.refine_scale2d_stop_iter > 0:
-                state["radii"].zero_()
         torch.cuda.empty_cache()

     def _update_state(
-            self,
-            params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
-            state: Dict[str, Any],
-            info: Dict[str, Any],
-            packed: bool = False,
+        self,
+        params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+        state: Dict[str, Any],
+        info: Dict[str, Any],
+        packed: bool = False,
     ):
         # Ensure required keys are present
         required_keys = ["width", "height", "n_cameras", "radii", "gaussian_ids"]
@@ -257,8 +259,9 @@ def _update_state(
             assert key in info, f"{key} is required but missing."

         # Normalize gradients to [-1, 1] screen space
+        factor = 0.5 * info["n_cameras"]
         scale_factors = torch.tensor(
-            [info["width"] / 2.0 * info["n_cameras"], info["height"] / 2.0 * info["n_cameras"]],
+            [info["width"] * factor, info["height"] * factor],
             device=info["means2d"].device,
         )

@@ -271,51 +274,38 @@ def _update_state(
         n_gaussian = params["anchors"].shape[0]
         device = grads.device
         if state["grad2d"] is None:
-            state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
+            state["grad2d"] = torch.zeros(
+                n_gaussian * self.n_feat_offsets, device=device
+            )
         if state["count"] is None:
-            state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
+            state["count"] = torch.zeros(
+                n_gaussian * self.n_feat_offsets, device=device
+            )
         if state["anchor_count"] is None:
             state["anchor_count"] = torch.zeros(n_gaussian, device=device)
         if state["anchor_opacity"] is None:
             state["anchor_opacity"] = torch.zeros(n_gaussian, device=device)
-        if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
-            state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
-
-        # Update the running state
-        if packed:
-            # grads is [nnz, 2]
-            gs_ids = info["gaussian_ids"]  # [nnz]
-            radii = info["radii"]  # [nnz]
-        else:
-            # grads is [C, N, 2]
-            sel = info["radii"] > 0.0  # [C, N]
-            gs_ids = sel.nonzero(as_tuple=False)[:, 1]  # [nnz]
-            grads = grads[sel]  # [nnz, 2]
-            radii = info["radii"][sel]  # [nnz]

-        # Compute valid_mask efficiently
-        visible_anchor_mask = info["visible_anchor_mask"]
         neural_selection_mask = info["neural_selection_mask"]
-        # Compute anchor indices
-        anchor_indices = gs_ids // self.n_feat_offsets
-
-        # Determine valid gs_ids based on visibility and selection masks
-        valid_mask = (
-            visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids]
-        )
+        # Update the running state
+        sel = info["radii"] > 0.0  # [C, N]
+        neural_ids = neural_selection_mask.nonzero(as_tuple=False).squeeze(-1)
+        grads = grads[sel].norm(dim=-1)  # [nnz]
+        sel = sel.squeeze(0)

-        # Filter gs_ids and grads based on valid_mask
-        valid_gs_ids = gs_ids[valid_mask]
-        valid_grads_norm = grads[valid_mask].norm(dim=-1)
+        valid_ids = neural_ids[sel]

         # Update state using index_add_
-        state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm)
+        state["grad2d"].index_add_(0, valid_ids, grads)
         state["count"].index_add_(
-            0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)
+            0,
+            valid_ids,
+            torch.ones((valid_ids.shape[0]), dtype=torch.float32, device=device),
         )

         # Update anchor opacity and count
+        visible_anchor_mask = info["visible_anchor_mask"]
         anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1)
         neural_opacities = (
             info["neural_opacities"]
@@ -330,23 +320,13 @@ def _update_state(
             0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32)
         )

-        # Update radii if required
-        if self.refine_scale2d_stop_iter > 0:
-            # Normalize radii to [0, 1] screen space
-            normalized_radii = radii / float(max(info["width"], info["height"]))
-            # Update radii using torch.maximum
-            state["radii"][gs_ids] = torch.maximum(
-                state["radii"][gs_ids], normalized_radii
-            )
-
     @torch.no_grad()
     def _anchor_growing(
         self,
         params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
         optimizers: Dict[str, torch.optim.Optimizer],
         state: Dict[str, Any],
-        step: int,
-    ) -> Tuple[int]:
+    ) -> int:
         """
         Implements the Anchor Growing algorithm as described in Algorithm 1 of the
         GS-Scaffold appendix:
@@ -362,12 +342,10 @@ def _anchor_growing(
             A dictionary of optimizers associated with the model's parameters.
             state (Dict[str, Any]):
             A dictionary containing the current state of the training process.
-            step (int):
-            The current step or iteration of the training process.

         Returns:
-            Tuple[int]:
-            A tuple containing the updated step value after anchor growing.
+            int:
+            Number of grown anchors.
         """

         count = state["count"]
@@ -375,105 +353,136 @@ def _anchor_growing(
         grads[grads.isnan()] = 0.0
         device = grads.device

-        # is_grad_high = (grads > self.grow_grad2d).squeeze(-1)
-        # is_small = (n
-        #     torch.exp(params["scales"]).max(dim=-1).values
-        #     <= self.grow_scale3d * state["scene_scale"]
-        # )
-        # is_dupli = is_grad_high & is_small
-        # n_dupli = is_dupli.sum().item()
-
-        n_init_features = params["features"].shape[0] * self.n_feat_offsets
+        n_init_features = state["count"].shape[0]
         n_added_anchors = 0

         # Algorithm 1: Anchor Growing
         # Step 1: Initialization
-        m = 1 # Iteration count (levels)
+        m = 1  # Iteration count (levels)
         M = self.max_voxel_levels
         tau_g = self.grow_grad2d
        epsilon_g = self.voxel_size
         new_anchors: torch.Tensor

-        growing_treshold = state["count"] > self.refine_every * self.growing_thresholds
+        growing_threshold_mask = (
+            state["count"] > self.refine_every * self.growing_thresholds
+        )

         # Step 2: Iterate while m <= M
         while m <= M:
-            n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features
-            # Check if anchor candidates have grown
-            if n_feature_diff == 0 and (m-1) > 0:
+            n_feature_diff = state["count"].shape[0] - n_init_features
+            # Check if anchor candidates have grown
+            if n_feature_diff == 0 and m > 1:
                 break

             # Step 3: Update threshold and voxel size
-            tau = tau_g * 2 ** (m - 1)
+            tau = tau_g * (2 ** (m - 1))
             current_voxel_size = (16 // (4 ** (m - 1))) * epsilon_g

             # Step 4: Mask from grad threshold. Select neural Gaussian candidates
             gradient_mask = grads >= tau
-            #
-            gradient_mask = torch.logical_and(gradient_mask, growing_treshold)
-            gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0)
-            neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:,:3].unsqueeze(dim=1)
-            selected_neural_gaussians = neural_gaussians.view([-1,3])[gradient_mask]
+            gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask)
+
+            # Drop-out: helps prevent excessive anchor growth.
+            rand_mask = torch.rand_like(gradient_mask.float()) > (0.4**m)
+            rand_mask = rand_mask.cuda()
+            gradient_mask = torch.logical_and(gradient_mask, rand_mask)
+            gradient_mask = torch.cat(
+                [
+                    gradient_mask,
+                    torch.zeros(n_feature_diff, dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )

-            # Step 5: Merge same positions
-            selected_grid_coords = torch.round(selected_neural_gaussians / current_voxel_size).int()
-            selected_grid_coords_unique, inv_idx = torch.unique(selected_grid_coords, return_inverse=True, dim=0)
+            # Compute neural gaussians
+            neural_gaussians = params["anchors"].unsqueeze(dim=1) + params[
+                "offsets"
+            ] * torch.exp(params["scales"][:, :3]).unsqueeze(dim=1)
+            selected_neural_gaussians = neural_gaussians.view([-1, 3])[gradient_mask]

-            # Step 6: Random elimination
-            # TODO: implement rand elimination (necessary)?
+            # Step 5: Merge same positions
+            selected_grid_coords = torch.round(
+                selected_neural_gaussians / current_voxel_size
+            ).int()
+            selected_grid_coords_unique, inv_idx = torch.unique(
+                selected_grid_coords, return_inverse=True, dim=0
+            )

             # Get the grid coordinates of the current anchors
-            grid_anchor_coords = torch.round(params["anchors"] / current_voxel_size).int()
+            grid_anchor_coords = torch.round(
+                params["anchors"] / current_voxel_size
+            ).int()
+
+            # Step 6: Remove occupied voxels by comparing the unique coordinates to current anchors
+            def coords_to_indices(coords, N):
+                """
+                Maps quantized multi-dimensional coordinates to unique 1D indices without using torch.matmul.
+
+                Args:
+                    coords (torch.Tensor): Tensor of shape [num_points, D], where D is the dimensionality.
+                    N (int): A large enough number to ensure uniqueness of indices.
+
+                Returns:
+                    torch.Tensor: Tensor of unique indices corresponding to the coordinates.
+ """ + D = coords.shape[1] + device = coords.device + dtype = coords.dtype # Keep the original data type + + # Compute N_powers as integers + N_powers = N ** torch.arange(D - 1, -1, -1, device=device, dtype=dtype) + + # Perform element-wise multiplication and sum along the last dimension + indices = (coords * N_powers).sum(dim=1) + + return indices + + # Quantize the coordinates + decimal_places = 6 # precision + scale = 10**decimal_places + + # Quantize the coordinates by scaling and converting to integers + selected_coords_quant = (selected_grid_coords_unique * scale).long() + anchor_coords_quant = (grid_anchor_coords * scale).long() + + # Compute the maximum coordinate value + max_coord_value = ( + torch.max(torch.cat([selected_coords_quant, anchor_coords_quant])) + 1 + ) + N = max_coord_value.item() + + # Compute unique indices for both coordinate sets + indices_selected = coords_to_indices(selected_coords_quant, N) + indices_anchor = coords_to_indices(anchor_coords_quant, N) - # Step 7: Remove occupied by comparing the unique coordinates to current anchors - remove_occupied_pos_mask = ~( - (selected_grid_coords_unique.unsqueeze(1) == grid_anchor_coords).all(-1).any(-1).view(-1)) + remove_occupied_pos_mask = ~torch.isin(indices_selected, indices_anchor) # New anchor candidates are those unique coordinates that are not duplicates new_anchors = selected_grid_coords_unique[remove_occupied_pos_mask] if new_anchors.shape[0] > 0: - grow_anchors(params=params, - optimizers=optimizers, - state=state, - anchors=new_anchors, - gradient_mask=gradient_mask, - remove_duplicates_mask=remove_occupied_pos_mask, - inv_idx=inv_idx, - voxel_size=current_voxel_size, - n_feat_offsets=self.n_feat_offsets, - feat_dim=self.feat_dim) + grow_anchors( + params=params, + optimizers=optimizers, + state=state, + anchors=new_anchors, + gradient_mask=gradient_mask, + remove_duplicates_mask=remove_occupied_pos_mask, + inv_idx=inv_idx, + voxel_size=current_voxel_size, + n_feat_offsets=self.n_feat_offsets, + feat_dim=self.feat_dim, + ) n_added_anchors += new_anchors.shape[0] m += 1 - return n_added_anchors + indices = torch.arange(n_init_features, device=growing_threshold_mask.device)[ + growing_threshold_mask + ] - @torch.no_grad() - def _prune_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - state: Dict[str, Any], - step: int, - ) -> int: - is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa - if step > self.reset_every: - is_too_big = ( - torch.exp(params["scales"]).max(dim=-1).values - > self.prune_scale3d * state["scene_scale"] - ) - # The official code also implements sreen-size pruning but - # it's actually not being used due to a bug: - # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 - # We implement it here for completeness but set `refine_scale2d_stop_iter` - # to 0 by default to disable it. 
- if step < self.refine_scale2d_stop_iter: - is_too_big |= state["radii"] > self.prune_scale2d + if indices.numel() > 0: + state["count"].index_fill_(0, indices, 0) + state["grad2d"].index_fill_(0, indices, 0) - is_prune = is_prune | is_too_big - - n_prune = is_prune.sum().item() - if n_prune > 0: - remove(params=params, optimizers=optimizers, state=state, mask=is_prune) - - return n_prune + return n_added_anchors From 1cb2247c1f6c4fcc91bc0ea2ab7dd8c9379b80f2 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Tue, 24 Sep 2024 20:18:02 +0200 Subject: [PATCH 07/29] use reloc instead of pruning --- examples/simple_trainer_scaffold.py | 18 ++-- gsplat/strategy/ops.py | 17 ++-- gsplat/strategy/scaffold.py | 124 +++++++++++++++++++++------- 3 files changed, 114 insertions(+), 45 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index bc4a41dfd..4150cb7f4 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -51,13 +51,13 @@ class Config: # Name of compression strategy to use compression: Optional[Literal["png"]] = None # Render trajectory path - render_traj_path: str = "interp" + render_traj_path: str = "ellipse" # Path to the Mip-NeRF 360 dataset - # data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden" - data_dir: str = "/home/paja/data/bike_aliked" + data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden" + #data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset - data_factor: int = 1 + data_factor: int = 2 # Directory to save results result_dir: str = "results" # Every N images there is a test image @@ -198,6 +198,7 @@ def create_splats_with_optimizers( features = torch.zeros((N, strategy.feat_dim)) offsets = torch.zeros((N, strategy.n_feat_offsets, 3)) + opacities = torch.logit(torch.full((N, 1), init_opacity)) # [N,] params = [ # name, value, lr ("anchors", torch.nn.Parameter(points), 0), @@ -205,6 +206,7 @@ def create_splats_with_optimizers( ("quats", torch.nn.Parameter(quats), 0.002), ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002), ("features", torch.nn.Parameter(features), 0.0075), + ("opacities", torch.nn.Parameter(opacities), 5e-2), ("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale), ("colors_mlp", strategy.colors_mlp.parameters(), 0.008), ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004), @@ -822,9 +824,9 @@ def train(self): scheduler.step() # eval the full set - # if step in [i - 1 for i in cfg.eval_steps]: - # self.eval(step) - # self.render_traj(step) + if step in [i - 1 for i in cfg.eval_steps]: + self.eval(step) + self.render_traj(step) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: @@ -990,7 +992,7 @@ def render_traj(self, step: int): # save to video video_dir = f"{cfg.result_dir}/videos" os.makedirs(video_dir, exist_ok=True) - writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=30) + writer = imageio.get_writer(f"{video_dir}/traj_{step}.mp4", fps=24) for canvas in canvas_all: writer.append_data(canvas) writer.close() diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 178ab2423..8c29f7b1e 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -269,7 +269,7 @@ def relocate( sampled_idxs = alive_indices[sampled_idxs] new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"])[sampled_idxs], + 
scales=torch.exp(params["scales"][:,:3])[sampled_idxs], ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -279,7 +279,7 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs] = torch.log(new_scales) + p[sampled_idxs][:,:3] = torch.log(new_scales) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p) @@ -288,11 +288,13 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v # update the parameters and the state in the optimizers - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + names = ["anchors", "scales", "quats", "features", "offsets", "opacities"] + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): - v[sampled_idxs] = 0 + if k == "anchor_count" or k == "anchor_opacity": + v[sampled_idxs] = 0 @torch.no_grad() @@ -419,6 +421,9 @@ def grow_anchors( ) feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim] + def inverse_sigmoid(x): + return torch.log(x / (1 - x)) + opacities = inverse_sigmoid(0.1 * torch.ones((anchors.shape[0], 1), dtype=torch.float, device="cuda")) # Initialize new offsets offsets = torch.zeros( (num_new, n_feat_offsets, 3), device=device @@ -435,6 +440,8 @@ def param_fn(name: str, p: Tensor) -> Tensor: p_new = torch.cat([p, feat], dim=0) elif name == "offsets": p_new = torch.cat([p, offsets], dim=0) + elif name == "opacities": + p_new = torch.cat([p, opacities], dim=0) else: raise ValueError(f"Parameter '{name}' not recognized.") return torch.nn.Parameter(p_new) @@ -446,7 +453,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v_new # Update parameters and optimizer states - names = ["anchors", "scales", "quats", "features", "offsets"] + names = ["anchors", "scales", "quats", "features", "offsets", "opacities"] _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # Update the extra running state diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index c3745dadb..10e3ccd34 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -4,6 +4,8 @@ from .base import Strategy from .ops import remove_anchors, grow_anchors +from .ops import relocate, sample_add +import math @dataclass @@ -75,6 +77,7 @@ class ScaffoldStrategy(Strategy): prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 refine_start_iter: int = 1500 + cap_max: int = 1_000_000 max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 @@ -123,7 +126,14 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # put them on the correct device. # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. 
+ + n_max = 51 + binoms = torch.zeros((n_max, n_max)) + for n in range(n_max): + for k in range(n + 1): + binoms[n, k] = math.comb(n, k) state = { + "binoms": binoms, "grad2d": None, "count": None, "anchor_count": None, @@ -161,6 +171,7 @@ def check_sanity( "offsets", "scales", "quats", + "opacities", "opacities_mlp", "colors_mlp", "scale_rot_mlp", @@ -196,15 +207,19 @@ def step_post_backward( packed: bool = False, ): """Callback function to be executed after the `loss.backward()` call.""" + if step >= self.refine_stop_iter: return + # move to the correct device + state["binoms"] = state["binoms"].to(params["anchors"].device) + binoms = state["binoms"] + if step > 500: self._update_state(params, state, info, packed=packed) if step > self.refine_start_iter and step % self.refine_every == 0: # grow GSs - print(f"init anchors: {len(params['anchors'])}") new_anchors = self._anchor_growing(params, optimizers, state) if self.verbose: print( @@ -212,37 +227,39 @@ def step_post_backward( f"Now having {len(params['anchors'])} anchors." ) - # prune anchors - low_opacity_mask = ( - state["anchor_opacity"] < self.prune_opa * state["anchor_count"] - ).squeeze() - anchor_mask = ( - state["anchor_count"] > self.pruning_thresholds * self.refine_every - ) # [N, 1] - is_prune = torch.logical_and(low_opacity_mask, anchor_mask) - - indices = anchor_mask.nonzero(as_tuple=False).squeeze() - # Efficiently set the specified indices to zero - state["anchor_count"].index_fill_(0, indices, 0) - state["anchor_opacity"].index_fill_(0, indices, 0) - - n_prune = is_prune.sum().item() - if n_prune > 0: - names = ["anchors", "scales", "quats", "features", "offsets"] - remove_anchors( - params=params, - optimizers=optimizers, - n_feat_offsets=self.n_feat_offsets, - state=state, - mask=is_prune, - names=names, - ) - - if self.verbose: - print( - f"Step {step}: {n_prune} anchors pruned. " - f"Now having {len(params['anchors'])} anchors." - ) + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) + print(f"Relocated anchors {n_relocated_gs}") + # # prune anchors + # low_opacity_mask = ( + # state["anchor_opacity"] < self.prune_opa * state["anchor_count"] + # ).squeeze() + # anchor_mask = ( + # state["anchor_count"] > self.pruning_thresholds * self.refine_every + # ) # [N, 1] + # is_prune = torch.logical_and(low_opacity_mask, anchor_mask) + # + # indices = anchor_mask.nonzero(as_tuple=False).squeeze() + # # Efficiently set the specified indices to zero + # state["anchor_count"].index_fill_(0, indices, 0) + # state["anchor_opacity"].index_fill_(0, indices, 0) + # + # n_prune = is_prune.sum().item() + # if n_prune > 0: + # names = ["anchors", "scales", "quats", "features", "offsets", "opacities"] + # remove_anchors( + # params=params, + # optimizers=optimizers, + # n_feat_offsets=self.n_feat_offsets, + # state=state, + # mask=is_prune, + # names=names, + # ) + # + # if self.verbose: + # print( + # f"Step {step}: {low_opacity_mask.sum().item()} anchors pruned. " + # f"Now having {len(params['anchors'])} anchors." 
+ # ) torch.cuda.empty_cache() @@ -486,3 +503,46 @@ def coords_to_indices(coords, N): state["grad2d"].index_fill_(0, indices, 0) return n_added_anchors + @torch.no_grad() + def _relocate_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + binoms: torch.Tensor, + state: Dict[str, Any], + ) -> int: + dead_mask = ( + state["anchor_opacity"] < self.prune_opa * state["anchor_count"] + ).squeeze() + n_gs = dead_mask.sum().item() + if n_gs > 0: + relocate( + params=params, + optimizers=optimizers, + state=state, + mask=dead_mask, + binoms=binoms, + min_opacity=self.prune_opa, + ) + return n_gs + + @torch.no_grad() + def _add_new_gs( + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + binoms: torch.Tensor, + ) -> int: + current_n_points = len(params["anchors"]) + n_target = min(self.cap_max, int(1.05 * current_n_points)) + n_gs = max(0, n_target - current_n_points) + if n_gs > 0: + sample_add( + params=params, + optimizers=optimizers, + state={}, + n=n_gs, + binoms=binoms, + min_opacity=self.min_opacity, + ) + return n_gs From 74279acd1db73a569d42540803a882ec1284b80f Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Tue, 24 Sep 2024 23:08:04 +0200 Subject: [PATCH 08/29] black --- examples/simple_trainer_scaffold.py | 2 +- gsplat/strategy/ops.py | 9 ++++++--- gsplat/strategy/scaffold.py | 19 ++++++++++--------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 4150cb7f4..bb56bb61c 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -55,7 +55,7 @@ class Config: # Path to the Mip-NeRF 360 dataset data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden" - #data_dir: str = "/home/paja/data/bike_aliked" + # data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset data_factor: int = 2 # Directory to save results diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 8c29f7b1e..262c9075f 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -269,7 +269,7 @@ def relocate( sampled_idxs = alive_indices[sampled_idxs] new_opacities, new_scales = compute_relocation( opacities=opacities[sampled_idxs], - scales=torch.exp(params["scales"][:,:3])[sampled_idxs], + scales=torch.exp(params["scales"][:, :3])[sampled_idxs], ratios=torch.bincount(sampled_idxs)[sampled_idxs] + 1, binoms=binoms, ) @@ -279,7 +279,7 @@ def param_fn(name: str, p: Tensor) -> Tensor: if name == "opacities": p[sampled_idxs] = torch.logit(new_opacities) elif name == "scales": - p[sampled_idxs][:,:3] = torch.log(new_scales) + p[sampled_idxs][:, :3] = torch.log(new_scales) p[dead_indices] = p[sampled_idxs] return torch.nn.Parameter(p) @@ -423,7 +423,10 @@ def grow_anchors( def inverse_sigmoid(x): return torch.log(x / (1 - x)) - opacities = inverse_sigmoid(0.1 * torch.ones((anchors.shape[0], 1), dtype=torch.float, device="cuda")) + + opacities = inverse_sigmoid( + 0.1 * torch.ones((anchors.shape[0], 1), dtype=torch.float, device="cuda") + ) # Initialize new offsets offsets = torch.zeros( (num_new, n_feat_offsets, 3), device=device diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 10e3ccd34..27036ef9a 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -503,13 +503,14 @@ def coords_to_indices(coords, N): state["grad2d"].index_fill_(0, 
indices, 0) return n_added_anchors + @torch.no_grad() def _relocate_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - binoms: torch.Tensor, - state: Dict[str, Any], + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + binoms: torch.Tensor, + state: Dict[str, Any], ) -> int: dead_mask = ( state["anchor_opacity"] < self.prune_opa * state["anchor_count"] @@ -528,10 +529,10 @@ def _relocate_gs( @torch.no_grad() def _add_new_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - binoms: torch.Tensor, + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + optimizers: Dict[str, torch.optim.Optimizer], + binoms: torch.Tensor, ) -> int: current_n_points = len(params["anchors"]) n_target = min(self.cap_max, int(1.05 * current_n_points)) From 94ace6ffd2ed6bdbaaa47278cfb88551a1054ae2 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 25 Sep 2024 09:03:12 +0200 Subject: [PATCH 09/29] fix foglet floaters --- examples/simple_trainer_scaffold.py | 11 ++++++----- gsplat/strategy/scaffold.py | 7 +++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index bb56bb61c..f5ed936f8 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -54,10 +54,10 @@ class Config: render_traj_path: str = "ellipse" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden" + data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" # data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset - data_factor: int = 2 + data_factor: int = 1 # Directory to save results result_dir: str = "results" # Every N images there is a test image @@ -430,11 +430,9 @@ def get_neural_gaussians( length = view_dir.norm(dim=1, keepdim=True) view_dir_normalized = view_dir / length # [M, 3] - view_length = torch.cat([view_dir_normalized, length], dim=1) - # See formula (9) and the appendix for the rest feature_view_dir = torch.cat( - [selected_features, view_length], dim=1 + [selected_features, view_dir_normalized], dim=1 ) # [M, c+3] k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor @@ -579,6 +577,9 @@ def train(self): torch.optim.lr_scheduler.ExponentialLR( self.optimizers["colors_mlp"], gamma=0.00625 ** (1.0 / max_steps) ), + torch.optim.lr_scheduler.ExponentialLR( + self.optimizers["scale_rot_mlp"], gamma=1.0 + ), ] if cfg.pose_opt: # pose optimization has a learning rate schedule diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 27036ef9a..a0d1eb287 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -95,23 +95,22 @@ class ScaffoldStrategy(Strategy): revised_opacity: bool = False verbose: bool = True - view_distance = 1 colors_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3 + 1, feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 3 * n_feat_offsets), torch.nn.Sigmoid(), ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3 + 1, feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh(), 
).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3 + 1, feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 7 * n_feat_offsets), ).cuda() From 48641772cc4d9d313e282926fa105cb1425ef618 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 25 Sep 2024 11:49:21 +0200 Subject: [PATCH 10/29] better rendering --- examples/simple_trainer_scaffold.py | 27 +++++++++++++++++---- gsplat/strategy/ops.py | 4 ++-- gsplat/strategy/scaffold.py | 37 +++++++---------------------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index f5ed936f8..35b1a3853 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -159,8 +159,6 @@ def adjust_steps(self, factor: float): strategy.refine_start_iter = int(strategy.refine_start_iter * factor) strategy.refine_stop_iter = int(strategy.refine_stop_iter * factor) strategy.voxel_size = self.voxel_size - # strategy.reset_every = int(strategy.reset_every * factor) - # strategy.refine_every = int(strategy.refine_every * factor) def create_splats_with_optimizers( @@ -201,7 +199,7 @@ def create_splats_with_optimizers( opacities = torch.logit(torch.full((N, 1), init_opacity)) # [N,] params = [ # name, value, lr - ("anchors", torch.nn.Parameter(points), 0), + ("anchors", torch.nn.Parameter(points), 0.0001), ("scales", torch.nn.Parameter(scales), 0.007), ("quats", torch.nn.Parameter(quats), 0.002), ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002), @@ -419,6 +417,7 @@ def get_neural_gaussians( selected_features = self.splats["features"][visible_anchor_mask] # [M, c] selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3] selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3] + selected_quats = self.splats["quats"][visible_anchor_mask] # [M, 4] selected_scales = torch.exp( self.splats["scales"][visible_anchor_mask] ) # [M, 6] @@ -458,6 +457,9 @@ def get_neural_gaussians( anchors_repeated = ( selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3) ) # [M*k, 3] + quats_repeated = ( + selected_quats.unsqueeze(1).repeat(1, k, 1).view(-1, 4) + ) # [M*k, 3] # Apply positive opacity mask selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1) # [M] @@ -466,12 +468,29 @@ def get_neural_gaussians( selected_offsets = selected_offsets[neural_selection_mask] # [M, 3] scales_repeated = scales_repeated[neural_selection_mask] # [M, 6] anchors_repeated = anchors_repeated[neural_selection_mask] # [M, 3] + quats_repeated = quats_repeated[neural_selection_mask] # [M, 3] + # Compute scales and rotations scales = scales_repeated[:, 3:] * torch.sigmoid( selected_scale_rot[:, :3] ) # [M, 3] - rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7]) # [M, 4] + def quaternion_multiply(q1, q2): + # Extract individual components of the quaternions + w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] + w2, x2, y2, z2 = q2[:, 0], q2[:, 1], q2[:, 2], q2[:, 3] + + # Perform the quaternion multiplication + w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2 + x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2 + y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2 + z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2 + + # Stack the result back into shape (N, 4) + return torch.stack((w, x, y, z), dim=-1) + + # The rasterizer takes care of the normalization + rotation = quaternion_multiply(quats_repeated,selected_scale_rot[:, 3:7]) # Compute offsets and 
anchors offsets = selected_offsets * scales_repeated[:, :3] # [M, 3] diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 262c9075f..9a928cc5e 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -461,7 +461,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: # Update the extra running state for k, v in state.items(): - if isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor) and k != "binoms": if k == "anchor_count" or k == "anchor_opacity": zeros = torch.zeros((num_new, *v.shape[1:]), device=device) else: @@ -502,7 +502,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # update the extra running state for k, v in state.items(): - if isinstance(v, torch.Tensor): + if isinstance(v, torch.Tensor) and k != "binoms": if k in ["anchor_count", "anchor_opacity"]: state[k] = v[sel] else: diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index a0d1eb287..1d34f0b27 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -226,9 +226,7 @@ def step_post_backward( f"Now having {len(params['anchors'])} anchors." ) - n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) - print(f"Relocated anchors {n_relocated_gs}") - # # prune anchors + #if step % 1000 == 0: # low_opacity_mask = ( # state["anchor_opacity"] < self.prune_opa * state["anchor_count"] # ).squeeze() @@ -237,7 +235,7 @@ def step_post_backward( # ) # [N, 1] # is_prune = torch.logical_and(low_opacity_mask, anchor_mask) # - # indices = anchor_mask.nonzero(as_tuple=False).squeeze() + # indices = is_prune.nonzero(as_tuple=False).squeeze() # # Efficiently set the specified indices to zero # state["anchor_count"].index_fill_(0, indices, 0) # state["anchor_opacity"].index_fill_(0, indices, 0) @@ -256,9 +254,13 @@ def step_post_backward( # # if self.verbose: # print( - # f"Step {step}: {low_opacity_mask.sum().item()} anchors pruned. " + # f"Step {step}: {n_prune} anchors pruned. " # f"Now having {len(params['anchors'])} anchors." # ) + # else: + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) + if self.verbose: + print(f"Relocated anchors {n_relocated_gs}") torch.cuda.empty_cache() @@ -399,7 +401,7 @@ def _anchor_growing( gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask) # Drop-out: Helps prevent too massive anchor growth. - rand_mask = torch.rand_like(gradient_mask.float()) > (0.4**m) + rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m) rand_mask = rand_mask.cuda() gradient_mask = torch.logical_and(gradient_mask, rand_mask) gradient_mask = torch.cat( @@ -524,25 +526,4 @@ def _relocate_gs( binoms=binoms, min_opacity=self.prune_opa, ) - return n_gs - - @torch.no_grad() - def _add_new_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - binoms: torch.Tensor, - ) -> int: - current_n_points = len(params["anchors"]) - n_target = min(self.cap_max, int(1.05 * current_n_points)) - n_gs = max(0, n_target - current_n_points) - if n_gs > 0: - sample_add( - params=params, - optimizers=optimizers, - state={}, - n=n_gs, - binoms=binoms, - min_opacity=self.min_opacity, - ) - return n_gs + return n_gs \ No newline at end of file From 1b663c54bd946da74533632b2b204a723a28ee07 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 25 Sep 2024 12:17:16 +0200 Subject: [PATCH 11/29] Makes no sense that cov and opacity are view dependent. 
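
In the view-dependent formulation used so far (compare the formula (9) reference
in the trainer), every decoder sees the anchor feature concatenated with the
normalized view direction; this patch keeps that conditioning only for
colors_mlp and feeds the plain anchor feature to opacities_mlp and
scale_rot_mlp. A minimal sketch of the two input variants (toy shapes only;
`features` and `view_dir_normalized` mirror the trainer code, assumed
[M, feat_dim] and [M, 3]):

    import torch

    M, feat_dim = 4, 32
    features = torch.randn(M, feat_dim)  # anchor features [M, feat_dim]
    view_dir_normalized = torch.nn.functional.normalize(
        torch.randn(M, 3), dim=1
    )  # camera-to-anchor directions [M, 3]

    # View-dependent decoder input (still used by colors_mlp):
    feature_view_dir = torch.cat([features, view_dir_normalized], dim=1)  # [M, feat_dim + 3]

    # View-independent decoder input (opacities_mlp / scale_rot_mlp after this patch):
    feature_only = features  # [M, feat_dim]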
--- examples/simple_trainer_scaffold.py | 4 ++-- gsplat/strategy/scaffold.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 35b1a3853..8b6bd790c 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -437,7 +437,7 @@ def get_neural_gaussians( k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor # Apply MLPs (they output per-offset features concatenated along the last dimension) - neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] + neural_opacity = self.cfg.strategy.opacities_mlp(selected_features) # [M, k*1] neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k] @@ -446,7 +446,7 @@ def get_neural_gaussians( neural_colors = neural_colors.view(-1, 3) # [M*k, 3] # Get scale and rotation and reshape - neural_scale_rot = self.cfg.strategy.scale_rot_mlp(feature_view_dir) # [M, k*7] + neural_scale_rot = self.cfg.strategy.scale_rot_mlp(selected_features) # [M, k*7] neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] # Reshape selected_offsets, scales, and anchors diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 1d34f0b27..65338e1af 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -103,14 +103,14 @@ class ScaffoldStrategy(Strategy): ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh(), ).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 7 * n_feat_offsets), ).cuda() From 0f7a572980ad2affbc10764eb932ba68f8ddbbc7 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 25 Sep 2024 12:38:29 +0200 Subject: [PATCH 12/29] Paper reports 0.001 scale reg factor. --- examples/simple_trainer_scaffold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 8b6bd790c..8497f115e 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -113,7 +113,7 @@ class Config: random_bkgd: bool = False # Scale regularization - scale_reg: float = 0.01 + scale_reg: float = 0.001 # Enable camera optimization. pose_opt: bool = False From e00d134d861cd3ca5916540b2b4176c44345bae9 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Wed, 25 Sep 2024 15:00:11 +0200 Subject: [PATCH 13/29] sota? 
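
The trainer change below also leaves an opacity binarization penalty commented
out. For reference, a small self-contained sketch of that term (the input values
here are made up, not from a training run): x * (1 - x) is zero at x = 0 and
x = 1 and peaks at 0.25 for x = 0.5, so minimizing its mean pushes sigmoid
opacities toward binary values.

    import torch

    def binarization_loss(x: torch.Tensor) -> torch.Tensor:
        # Zero at x in {0, 1}, maximal (0.25) at x = 0.5.
        return (x * (1 - x)).mean()

    opacities = torch.tensor([0.0, 0.25, 0.5, 0.9, 1.0])  # hypothetical sigmoid outputs
    print(binarization_loss(opacities))  # tensor(0.1055), dominated by mid-range values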
--- examples/simple_trainer_scaffold.py | 27 +++++++++++++++++------ gsplat/strategy/scaffold.py | 34 ++++++++++++++++++----------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 8497f115e..f0c6f73da 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -54,10 +54,11 @@ class Config: render_traj_path: str = "ellipse" # Path to the Mip-NeRF 360 dataset - data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" - # data_dir: str = "/home/paja/data/bike_aliked" + #data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" + data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden/" + #data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset - data_factor: int = 1 + data_factor: int = 4 # Directory to save results result_dir: str = "results" # Every N images there is a test image @@ -207,7 +208,7 @@ def create_splats_with_optimizers( ("opacities", torch.nn.Parameter(opacities), 5e-2), ("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale), ("colors_mlp", strategy.colors_mlp.parameters(), 0.008), - ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004), + ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.0004), ] splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) @@ -437,7 +438,7 @@ def get_neural_gaussians( k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor # Apply MLPs (they output per-offset features concatenated along the last dimension) - neural_opacity = self.cfg.strategy.opacities_mlp(selected_features) # [M, k*1] + neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k] @@ -446,7 +447,7 @@ def get_neural_gaussians( neural_colors = neural_colors.view(-1, 3) # [M*k, 3] # Get scale and rotation and reshape - neural_scale_rot = self.cfg.strategy.scale_rot_mlp(selected_features) # [M, k*7] + neural_scale_rot = self.cfg.strategy.scale_rot_mlp(feature_view_dir) # [M, k*7] neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] # Reshape selected_offsets, scales, and anchors @@ -720,6 +721,18 @@ def train(self): loss = l1loss * (1.0 - cfg.ssim_lambda) loss += ssimloss * cfg.ssim_lambda loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg + + # Apply sigmoid to normalize values to [0, 1] + # sigmoid_opacities = torch.sigmoid(info["opacities"]) + # + # # Custom loss to penalize values not close to 0 or 1 + # def binarization_loss(x): + # return (x * (1 - x)).mean() + # + # # Calculate the binarization loss + # opa_loss = binarization_loss(sigmoid_opacities) + # loss += 0.01 * opa_loss + if cfg.depth_loss: # query depths from depth map points = torch.stack( @@ -846,7 +859,7 @@ def train(self): # eval the full set if step in [i - 1 for i in cfg.eval_steps]: self.eval(step) - self.render_traj(step) + # self.render_traj(step) # run compression if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 65338e1af..43b6cec05 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -4,7 +4,7 @@ from .base import Strategy from .ops import remove_anchors, grow_anchors -from .ops import relocate, sample_add +from .ops import relocate import math @@ -49,8 +49,6 @@ 
class ScaffoldStrategy(Strategy): reset, Default is 0 (no pause at all) and one might want to set this number to the number of images in training set. absgrad (bool): Use absolute gradients for GS splitting. Default is False. - revised_opacity (bool): Whether to use revised opacity heuristic from - arXiv:2404.06109 (experimental). Default is False. verbose (bool): Whether to print verbose information. Default is False. Examples: @@ -71,17 +69,16 @@ class ScaffoldStrategy(Strategy): """ prune_opa: float = 0.005 - grow_grad2d: float = 0.0002 + grow_grad2d: float = 1.28e-4 grow_scale3d: float = 0.01 grow_scale2d: float = 0.05 prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 - refine_start_iter: int = 1500 - cap_max: int = 1_000_000 + refine_start_iter: int = 800 max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 - feat_dim: int = 32 + feat_dim: int = 128 n_feat_offsets: int = 10 refine_every: int = 100 @@ -92,7 +89,6 @@ class ScaffoldStrategy(Strategy): pause_refine_after_reset: int = 0 absgrad: bool = False - revised_opacity: bool = False verbose: bool = True colors_mlp: torch.nn.Sequential = torch.nn.Sequential( @@ -103,14 +99,14 @@ class ScaffoldStrategy(Strategy): ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim, feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh(), ).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim, feat_dim), + torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 7 * n_feat_offsets), ).cuda() @@ -262,6 +258,18 @@ def step_post_backward( if self.verbose: print(f"Relocated anchors {n_relocated_gs}") + # def op_sigmoid(x, k=100, x0=0.995): + # return 1 / (1 + torch.exp(-k * (x - x0))) + # + # opacities = torch.sigmoid(params["opacities"]) + # noise = ( + # torch.randn_like(params["offsets"]) + # * (op_sigmoid(1 - opacities)).unsqueeze(-1) + # + # * 5e5 * 0.00001 + # ) + # + # params["offsets"] = params["offsets"] + noise torch.cuda.empty_cache() def _update_state( @@ -401,9 +409,9 @@ def _anchor_growing( gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask) # Drop-out: Helps prevent too massive anchor growth. 
- rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m) - rand_mask = rand_mask.cuda() - gradient_mask = torch.logical_and(gradient_mask, rand_mask) + # rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m) + # rand_mask = rand_mask.cuda() + # gradient_mask = torch.logical_and(gradient_mask, rand_mask) gradient_mask = torch.cat( [ gradient_mask, From 43e676f5ff81696eed55f839e08b59438b62058f Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Thu, 26 Sep 2024 11:44:31 +0200 Subject: [PATCH 14/29] fix crash --- gsplat/strategy/ops.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 9a928cc5e..caf284d5f 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -409,9 +409,7 @@ def grow_anchors( # Prepare new features existing_features = params["features"] # [N_existing, feat_dim] - repeated_features = existing_features.repeat_interleave( - n_feat_offsets, dim=0 - ) # [N_existing * n_feat_offsets, feat_dim] + repeated_features = existing_features.unsqueeze(1).expand(-1, n_feat_offsets, -1).reshape(-1, existing_features.shape[1]) # [N_existing * n_feat_offsets, feat_dim] selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] From b816d5d1544d13585d8343dbd2dc9d9f792ff925 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Thu, 26 Sep 2024 13:53:54 +0200 Subject: [PATCH 15/29] black and absgrad --- examples/simple_trainer_scaffold.py | 8 ++++---- gsplat/strategy/ops.py | 6 +++++- gsplat/strategy/scaffold.py | 5 +++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index f0c6f73da..f71e177ab 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -54,9 +54,9 @@ class Config: render_traj_path: str = "ellipse" # Path to the Mip-NeRF 360 dataset + data_dir: str = "examples/data/360_v2/garden" #data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" - data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden/" - #data_dir: str = "/home/paja/data/bike_aliked" + # data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset data_factor: int = 4 # Directory to save results @@ -471,11 +471,11 @@ def get_neural_gaussians( anchors_repeated = anchors_repeated[neural_selection_mask] # [M, 3] quats_repeated = quats_repeated[neural_selection_mask] # [M, 3] - # Compute scales and rotations scales = scales_repeated[:, 3:] * torch.sigmoid( selected_scale_rot[:, :3] ) # [M, 3] + def quaternion_multiply(q1, q2): # Extract individual components of the quaternions w1, x1, y1, z1 = q1[:, 0], q1[:, 1], q1[:, 2], q1[:, 3] @@ -491,7 +491,7 @@ def quaternion_multiply(q1, q2): return torch.stack((w, x, y, z), dim=-1) # The rasterizer takes care of the normalization - rotation = quaternion_multiply(quats_repeated,selected_scale_rot[:, 3:7]) + rotation = quaternion_multiply(quats_repeated, selected_scale_rot[:, 3:7]) # Compute offsets and anchors offsets = selected_offsets * scales_repeated[:, :3] # [M, 3] diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index caf284d5f..7c55bf83d 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -409,7 +409,11 @@ def grow_anchors( # Prepare new features existing_features = params["features"] # [N_existing, feat_dim] - repeated_features = existing_features.unsqueeze(1).expand(-1, n_feat_offsets, -1).reshape(-1, existing_features.shape[1]) # [N_existing * 
n_feat_offsets, feat_dim] + repeated_features = ( + existing_features.unsqueeze(1) + .expand(-1, n_feat_offsets, -1) + .reshape(-1, existing_features.shape[1]) + ) # [N_existing * n_feat_offsets, feat_dim] selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 43b6cec05..859cc440e 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -75,6 +75,7 @@ class ScaffoldStrategy(Strategy): prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 refine_start_iter: int = 800 + absgrad: bool = False max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 @@ -222,7 +223,7 @@ def step_post_backward( f"Now having {len(params['anchors'])} anchors." ) - #if step % 1000 == 0: + # if step % 1000 == 0: # low_opacity_mask = ( # state["anchor_opacity"] < self.prune_opa * state["anchor_count"] # ).squeeze() @@ -534,4 +535,4 @@ def _relocate_gs( binoms=binoms, min_opacity=self.prune_opa, ) - return n_gs \ No newline at end of file + return n_gs From 556d7299a041a41a1b3e3d7e6819c90bef7c54c4 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Thu, 26 Sep 2024 16:20:08 +0200 Subject: [PATCH 16/29] restructure according to Ruilong's suggestions --- examples/simple_trainer_scaffold.py | 247 ++++++++++++++++++++-------- gsplat/strategy/ops.py | 6 +- gsplat/strategy/scaffold.py | 43 ++--- 3 files changed, 187 insertions(+), 109 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index f71e177ab..584c75c3d 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -15,6 +15,9 @@ import tyro import viser import yaml +from torch.nn import ModuleDict, ParameterDict +from torch.optim import SparseAdam, Adam + from datasets.colmap import Dataset, Parser from datasets.traj import ( generate_interpolated_path, @@ -54,8 +57,8 @@ class Config: render_traj_path: str = "ellipse" # Path to the Mip-NeRF 360 dataset - data_dir: str = "examples/data/360_v2/garden" - #data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" + data_dir: str = "examples/data/360_v2/room" + # data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/" # data_dir: str = "/home/paja/data/bike_aliked" # Downsample factor for the dataset data_factor: int = 4 @@ -69,6 +72,10 @@ class Config: global_scale: float = 1.0 # Normalize the world space normalize_world_space: bool = True + # Dimensionality of anchor features + feat_dim: int = 128 + # Number offsets + n_feat_offsets: int = 10 # Port for the viewer server port: int = 8080 @@ -164,12 +171,10 @@ def adjust_steps(self, factor: float): def create_splats_with_optimizers( parser: Parser, - strategy: ScaffoldStrategy, init_extent: float = 3.0, init_opacity: float = 0.1, init_scale: float = 1.0, scene_scale: float = 1.0, - sh_degree: int = 3, sparse_grad: bool = False, batch_size: int = 1, feature_dim: Optional[int] = None, @@ -177,7 +182,9 @@ def create_splats_with_optimizers( world_rank: int = 0, world_size: int = 1, voxel_size: float = 0.001, -) -> Tuple[torch.nn.ParameterDict, Dict[str, torch.optim.Optimizer]]: +) -> tuple[ + dict[str, ModuleDict | ParameterDict], dict[str, dict[str, SparseAdam | Adam]] +]: # Compare GS-Scaffold paper formula (4) points = np.unique(np.round(parser.points / voxel_size), axis=0) * voxel_size @@ -194,38 +201,102 @@ def create_splats_with_optimizers( N = points.shape[0] quats = torch.rand((N, 4)) # [N, 4] 
- features = torch.zeros((N, strategy.feat_dim)) - offsets = torch.zeros((N, strategy.n_feat_offsets, 3)) + features = torch.zeros((N, cfg.feat_dim)) + offsets = torch.zeros((N, cfg.n_feat_offsets, 3)) opacities = torch.logit(torch.full((N, 1), init_opacity)) # [N,] - params = [ - # name, value, lr - ("anchors", torch.nn.Parameter(points), 0.0001), - ("scales", torch.nn.Parameter(scales), 0.007), - ("quats", torch.nn.Parameter(quats), 0.002), - ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002), - ("features", torch.nn.Parameter(features), 0.0075), - ("opacities", torch.nn.Parameter(opacities), 5e-2), - ("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale), - ("colors_mlp", strategy.colors_mlp.parameters(), 0.008), - ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.0004), - ] - - splats = torch.nn.ParameterDict({n: v for n, v, _ in params}).to(device) - # Scale learning rate based on batch size, reference: - # https://www.cs.princeton.edu/~smalladi/blog/2024/01/22/SDEs-ScalingRules/ - # Note that this would not make the training exactly equivalent, see - # https://arxiv.org/pdf/2402.18824v1 + + # Define learning rates for gauss_params and decoders + learning_rates = { + "anchors": 0.0001, + "scales": 0.007, + "quats": 0.002, + "features": 0.0075, + "opacities": 5e-2, + "offsets": 0.01 * scene_scale, + "opacities_mlp": 0.002, + "colors_mlp": 0.008, + "scale_rot_mlp": 0.0004, + } + + # Define gauss_params + gauss_params = torch.nn.ParameterDict( + { + "anchors": torch.nn.Parameter(points), + "scales": torch.nn.Parameter(scales), + "quats": torch.nn.Parameter(quats), + "features": torch.nn.Parameter(features), + "opacities": torch.nn.Parameter(opacities), + "offsets": torch.nn.Parameter(offsets), + } + ).to(device) + + # Define the MLPs (decoders) + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, 3 * cfg.n_feat_offsets), + torch.nn.Sigmoid(), + ).cuda() + + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, cfg.n_feat_offsets), + torch.nn.Tanh(), + ).cuda() + + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( + torch.nn.Linear(cfg.feat_dim + 3, cfg.feat_dim), + torch.nn.ReLU(True), + torch.nn.Linear(cfg.feat_dim, 7 * cfg.n_feat_offsets), + ).cuda() + + # Initialize decoders (MLPs) + decoders = torch.nn.ModuleDict( + { + "opacities_mlp": opacities_mlp, + "colors_mlp": colors_mlp, + "scale_rot_mlp": scale_rot_mlp, + } + ).to(device) + + # Scale learning rates based on batch size (BS) BS = batch_size * world_size - optimizers = { + + # Create optimizers for gauss_params + gauss_optimizers = { name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( - [{"params": splats[name], "lr": lr * math.sqrt(BS), "name": name}], + [{"params": param, "lr": learning_rates[name] * math.sqrt(BS)}], eps=1e-15 / math.sqrt(BS), - # TODO: check betas logic when BS is larger than 10 betas[0] will be zero. 
betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), ) - for name, _, lr in params + for name, param in gauss_params.items() } + + # Create optimizers for decoders + decoders_optimizers = { + name: (torch.optim.SparseAdam if sparse_grad else torch.optim.Adam)( + [ + { + "params": decoder.parameters(), + "lr": learning_rates[name] * math.sqrt(BS), + } + ], + eps=1e-15 / math.sqrt(BS), + betas=(1 - BS * (1 - 0.9), 1 - BS * (1 - 0.999)), + ) + for name, decoder in decoders.items() + } + + # Combine gauss_params and decoders optimizers into a dictionary of dictionaries + optimizers = { + "gauss_optimizer": gauss_optimizers, + "decoders_optimizer": decoders_optimizers, + } + + # Return the gauss_params, decoders, and the dictionary of dictionaries for optimizers + splats = {"gauss_params": gauss_params, "decoders": decoders} return splats, optimizers @@ -279,7 +350,6 @@ def __init__( feature_dim = 32 if cfg.app_opt else None self.splats, self.optimizers = create_splats_with_optimizers( self.parser, - strategy=self.cfg.strategy, init_extent=cfg.init_extent, init_opacity=cfg.init_opa, init_scale=cfg.init_scale, @@ -292,13 +362,20 @@ def __init__( world_size=world_size, voxel_size=cfg.voxel_size, ) - print("Model initialized. Number of GS:", len(self.splats["anchors"])) + print( + "Model initialized. Number of GS:", + len(self.splats["gauss_params"]["anchors"]), + ) # Densification Strategy - self.cfg.strategy.check_sanity(self.splats, self.optimizers) + self.cfg.strategy.check_sanity( + self.splats["gauss_params"], self.optimizers["gauss_optimizer"] + ) self.strategy_state = self.cfg.strategy.initialize_state( - scene_scale=self.scene_scale + scene_scale=self.scene_scale, + feat_dim=cfg.feat_dim, + n_feat_offsets=cfg.n_feat_offsets, ) # Compression Strategy self.compression_method = None @@ -404,9 +481,9 @@ def get_neural_gaussians( # Compare paper: Helps mainly to speed up the rasterization. 
Has no quality impact visible_anchor_mask = view_to_visible_anchors( - means=self.splats["anchors"], - quats=self.splats["quats"], - scales=torch.exp(self.splats["scales"])[:, :3], + means=self.splats["gauss_params"]["anchors"], + quats=self.splats["gauss_params"]["quats"], + scales=torch.exp(self.splats["gauss_params"]["scales"][:, :3]), viewmats=torch.linalg.inv(camtoworlds), # [C, 4, 4] Ks=Ks, # [C, 3, 3] width=width, @@ -415,12 +492,20 @@ def get_neural_gaussians( rasterize_mode=rasterize_mode, ) # If no visibility mask is provided, we select all anchors including their offsets - selected_features = self.splats["features"][visible_anchor_mask] # [M, c] - selected_anchors = self.splats["anchors"][visible_anchor_mask] # [M, 3] - selected_offsets = self.splats["offsets"][visible_anchor_mask] # [M, k, 3] - selected_quats = self.splats["quats"][visible_anchor_mask] # [M, 4] + selected_features = self.splats["gauss_params"]["features"][ + visible_anchor_mask + ] # [M, c] + selected_anchors = self.splats["gauss_params"]["anchors"][ + visible_anchor_mask + ] # [M, 3] + selected_offsets = self.splats["gauss_params"]["offsets"][ + visible_anchor_mask + ] # [M, k, 3] + selected_quats = self.splats["gauss_params"]["quats"][ + visible_anchor_mask + ] # [M, 4] selected_scales = torch.exp( - self.splats["scales"][visible_anchor_mask] + self.splats["gauss_params"]["scales"][visible_anchor_mask] ) # [M, 6] # See formula (5) in Scaffold-GS @@ -435,19 +520,25 @@ def get_neural_gaussians( [selected_features, view_dir_normalized], dim=1 ) # [M, c+3] - k = self.cfg.strategy.n_feat_offsets # Number of offsets per anchor + k = self.cfg.n_feat_offsets # Number of offsets per anchor # Apply MLPs (they output per-offset features concatenated along the last dimension) - neural_opacity = self.cfg.strategy.opacities_mlp(feature_view_dir) # [M, k*1] + neural_opacity = self.splats["decoders"]["opacities_mlp"]( + feature_view_dir + ) # [M, k*1] neural_opacity = neural_opacity.view(-1, 1) # [M*k, 1] neural_selection_mask = (neural_opacity > 0.0).view(-1) # [M*k] # Get color and reshape - neural_colors = self.cfg.strategy.colors_mlp(feature_view_dir) # [M, k*3] + neural_colors = self.splats["decoders"]["colors_mlp"]( + feature_view_dir + ) # [M, k*3] neural_colors = neural_colors.view(-1, 3) # [M*k, 3] # Get scale and rotation and reshape - neural_scale_rot = self.cfg.strategy.scale_rot_mlp(feature_view_dir) # [M, k*7] + neural_scale_rot = self.splats["decoders"]["scale_rot_mlp"]( + feature_view_dir + ) # [M, k*7] neural_scale_rot = neural_scale_rot.view(-1, 7) # [M*k, 7] # Reshape selected_offsets, scales, and anchors @@ -499,7 +590,7 @@ def quaternion_multiply(q1, q2): v_a = ( visible_anchor_mask.unsqueeze(dim=1) - .repeat([1, self.cfg.strategy.n_feat_offsets]) + .repeat([1, self.cfg.n_feat_offsets]) .view(-1) ) all_neural_gaussians = torch.zeros_like(v_a, dtype=torch.bool) @@ -539,7 +630,7 @@ def rasterize_splats( image_ids = kwargs.pop("image_ids", None) if self.cfg.app_opt: colors = self.app_module( - features=self.splats["features"], + features=self.splats["gauss_params"]["features"], embed_ids=image_ids, dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3], ) @@ -585,20 +676,23 @@ def train(self): schedulers = [ torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["anchors"], gamma=0.001 ** (1.0 / max_steps) + self.optimizers["gauss_optimizer"]["anchors"], + gamma=0.001 ** (1.0 / max_steps), ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["offsets"], + 
self.optimizers["gauss_optimizer"]["offsets"], gamma=(0.01 * self.scene_scale) ** (1.0 / max_steps), ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["opacities_mlp"], gamma=0.001 ** (1.0 / max_steps) + self.optimizers["decoders_optimizer"]["opacities_mlp"], + gamma=0.001 ** (1.0 / max_steps), ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["colors_mlp"], gamma=0.00625 ** (1.0 / max_steps) + self.optimizers["decoders_optimizer"]["colors_mlp"], + gamma=0.00625 ** (1.0 / max_steps), ), torch.optim.lr_scheduler.ExponentialLR( - self.optimizers["scale_rot_mlp"], gamma=1.0 + self.optimizers["decoders_optimizer"]["scale_rot_mlp"], gamma=1.0 ), ] if cfg.pose_opt: @@ -639,9 +733,9 @@ def train(self): global_tic = time.time() pbar = tqdm.tqdm(range(init_step, max_steps)) - self.cfg.strategy.scale_rot_mlp.train() - self.cfg.strategy.opacities_mlp.train() - self.cfg.strategy.colors_mlp.train() + self.splats["decoders"]["scale_rot_mlp"].train() + self.splats["decoders"]["opacities_mlp"].train() + self.splats["decoders"]["colors_mlp"].train() for step in pbar: if not cfg.disable_viewer: @@ -706,8 +800,8 @@ def train(self): colors = colors + bkgd * (1.0 - alphas) self.cfg.strategy.step_pre_backward( - params=self.splats, - optimizers=self.optimizers, + params=self.splats["gauss_params"], + optimizers=self.optimizers["gauss_optimizer"], state=self.strategy_state, step=step, info=info, @@ -773,7 +867,7 @@ def train(self): self.writer.add_scalar("train/l1loss", l1loss.item(), step) self.writer.add_scalar("train/ssimloss", ssimloss.item(), step) self.writer.add_scalar( - "train/num_GS", len(self.splats["anchors"]), step + "train/num_GS", len(self.splats["gauss_params"]["anchors"]), step ) self.writer.add_scalar("train/mem", mem, step) if cfg.depth_loss: @@ -792,7 +886,7 @@ def train(self): stats = { "mem": mem, "ellipse_time": time.time() - global_tic, - "num_GS": len(self.splats["anchors"]), + "num_GS": len(self.splats["gauss_params"]["anchors"]), } print("Step: ", step, stats) with open( @@ -800,7 +894,10 @@ def train(self): "w", ) as f: json.dump(stats, f) - data = {"step": step, "splats": self.splats.state_dict()} + data = { + "step": step, + "splats": self.splats["gauss_params"].state_dict(), + } if cfg.pose_opt: if world_size > 1: data["pose_adjust"] = self.pose_adjust.module.state_dict() @@ -817,8 +914,8 @@ def train(self): # For now no post steps self.cfg.strategy.step_post_backward( - params=self.splats, - optimizers=self.optimizers, + params=self.splats["gauss_params"], + optimizers=self.optimizers["gauss_optimizer"], state=self.strategy_state, step=step, info=info, @@ -829,19 +926,23 @@ def train(self): if cfg.sparse_grad: assert cfg.packed, "Sparse gradients only work with packed mode." gaussian_ids = info["gaussian_ids"] - for k in self.splats.keys(): - grad = self.splats[k].grad + for k in self.splats["gauss_params"].keys(): + grad = self.splats["gauss_params"][k].grad if grad is None or grad.is_sparse: continue - self.splats[k].grad = torch.sparse_coo_tensor( + self.splats["gauss_params"][k].grad = torch.sparse_coo_tensor( indices=gaussian_ids[None], # [1, nnz] values=grad[gaussian_ids], # [nnz, ...] - size=self.splats[k].size(), # [N, ...] + size=self.splats["gauss_params"][k].size(), # [N, ...] 
is_coalesced=len(Ks) == 1, ) # optimize - for optimizer in self.optimizers.values(): + for optimizer in self.optimizers["gauss_optimizer"].values(): + optimizer.step() + optimizer.zero_grad(set_to_none=True) + # optimize + for optimizer in self.optimizers["decoders_optimizer"].values(): optimizer.step() optimizer.zero_grad(set_to_none=True) for optimizer in self.pose_optimizers: @@ -939,7 +1040,7 @@ def eval(self, step: int, stage: str = "val"): stats.update( { "ellipse_time": ellipse_time, - "num_GS": len(self.splats["anchors"]), + "num_GS": len(self.splats["gauss_params"]["anchors"]), } ) print( @@ -1040,12 +1141,12 @@ def run_compression(self, step: int): compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" os.makedirs(compress_dir, exist_ok=True) - self.compression_method.compress(compress_dir, self.splats) + self.compression_method.compress(compress_dir, self.splats["gauss_params"]) # evaluate compression splats_c = self.compression_method.decompress(compress_dir) for k in splats_c.keys(): - self.splats[k].data = splats_c[k].to(self.device) + self.splats["gauss_params"][k].data = splats_c[k].to(self.device) self.eval(step=step, stage="compress") @torch.no_grad() @@ -1084,8 +1185,10 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): torch.load(file, map_location=runner.device, weights_only=True) for file in cfg.ckpt ] - for k in runner.splats.keys(): - runner.splats[k].data = torch.cat([ckpt["splats"][k] for ckpt in ckpts]) + for k in runner.splats["gauss_params"].keys(): + runner.splats["gauss_params"][k].data = torch.cat( + [ckpt["splats"][k] for ckpt in ckpts] + ) step = ckpts[0]["step"] runner.eval(step=step) runner.render_traj(step=step) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 7c55bf83d..c00b34a67 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -288,8 +288,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v # update the parameters and the state in the optimizers - names = ["anchors", "scales", "quats", "features", "offsets", "opacities"] - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -458,8 +457,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v_new # Update parameters and optimizer states - names = ["anchors", "scales", "quats", "features", "offsets", "opacities"] - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) # Update the extra running state for k, v in state.items(): diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 859cc440e..1585299fa 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -79,8 +79,6 @@ class ScaffoldStrategy(Strategy): max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 - feat_dim: int = 128 - n_feat_offsets: int = 10 refine_every: int = 100 # 3.3 Observation Thresholds (compare paper) @@ -92,27 +90,9 @@ class ScaffoldStrategy(Strategy): absgrad: bool = False verbose: bool = True - colors_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), - torch.nn.ReLU(True), - torch.nn.Linear(feat_dim, 3 * n_feat_offsets), - torch.nn.Sigmoid(), - ).cuda() - - opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 
3, feat_dim),
-        torch.nn.ReLU(True),
-        torch.nn.Linear(feat_dim, n_feat_offsets),
-        torch.nn.Tanh(),
-    ).cuda()
-
-    scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential(
-        torch.nn.Linear(feat_dim + 3, feat_dim),
-        torch.nn.ReLU(True),
-        torch.nn.Linear(feat_dim, 7 * n_feat_offsets),
-    ).cuda()
-
-    def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
+    def initialize_state(
+        self, scene_scale: float = 1.0, feat_dim=128, n_feat_offsets=10
+    ) -> Dict[str, Any]:
         """Initialize and return the running state for this strategy.
 
         The returned state should be passed to the `step_pre_backward()` and
@@ -135,6 +115,8 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]:
             "anchor_count": None,
             "anchor_opacity": None,
             "scene_scale": scene_scale,
+            "feat_dim": feat_dim,
+            "n_feat_offsets": n_feat_offsets,
         }
         return state
 
@@ -168,9 +150,6 @@ def check_sanity(
             "scales",
             "quats",
             "opacities",
-            "opacities_mlp",
-            "colors_mlp",
-            "scale_rot_mlp",
         ]
 
         assert len(expected_params) == len(
@@ -239,14 +218,12 @@ def step_post_backward(
         #
         # n_prune = is_prune.sum().item()
         # if n_prune > 0:
-        #     names = ["anchors", "scales", "quats", "features", "offsets", "opacities"]
         #     remove_anchors(
         #         params=params,
         #         optimizers=optimizers,
         #         n_feat_offsets=self.n_feat_offsets,
         #         state=state,
         #         mask=is_prune,
-        #         names=names,
         #     )
         #
         #     if self.verbose:
@@ -302,11 +279,11 @@ def _update_state(
         device = grads.device
         if state["grad2d"] is None:
             state["grad2d"] = torch.zeros(
-                n_gaussian * self.n_feat_offsets, device=device
+                n_gaussian * state["n_feat_offsets"], device=device
             )
         if state["count"] is None:
             state["count"] = torch.zeros(
-                n_gaussian * self.n_feat_offsets, device=device
+                n_gaussian * state["n_feat_offsets"], device=device
             )
         if state["anchor_count"] is None:
             state["anchor_count"] = torch.zeros(n_gaussian, device=device)
@@ -337,7 +314,7 @@ def _update_state(
         neural_opacities = (
             info["neural_opacities"]
             .detach()
-            .view(-1, self.n_feat_offsets)
+            .view(-1, state["n_feat_offsets"])
             .clamp_min_(0)
             .sum(dim=1)
         )
@@ -497,8 +474,8 @@ def coords_to_indices(coords, N):
                 remove_duplicates_mask=remove_occupied_pos_mask,
                 inv_idx=inv_idx,
                 voxel_size=current_voxel_size,
-                n_feat_offsets=self.n_feat_offsets,
-                feat_dim=self.feat_dim,
+                n_feat_offsets=state["n_feat_offsets"],
+                feat_dim=state["feat_dim"],
             )
 
             n_added_anchors += new_anchors.shape[0]

From 375bd67628edc63d21bb1518bf97e1570491d2e8 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Thu, 26 Sep 2024 17:14:38 +0200
Subject: [PATCH 17/29] github actions seem not to like my black

---
 gsplat/compression/png_compression.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/gsplat/compression/png_compression.py b/gsplat/compression/png_compression.py
index 2dbece8cf..ee4ead6f9 100644
--- a/gsplat/compression/png_compression.py
+++ b/gsplat/compression/png_compression.py
@@ -368,7 +368,9 @@ def _compress_kmeans(
     maxs = torch.max(centroids)
     centroids_norm = (centroids - mins) / (maxs - mins)
     centroids_norm = centroids_norm.detach().cpu().numpy()
-    centroids_quant = (centroids_norm * (2**quantization - 1)).round().astype(np.uint8)
+    centroids_quant = (
+        (centroids_norm * (2**quantization - 1)).round().astype(np.uint8)
+    )
     labels = labels.astype(np.uint16)
 
     npz_dict = {

From f9b8d04a17900932dda8d424448b8190b1a8fee3 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Fri, 27 Sep 2024 09:23:27 +0200
Subject: [PATCH 18/29] add ckpt save and loading

---
 examples/benchmarks/scaffold.sh     | 53 ++++++++++++++
examples/simple_trainer_scaffold.py | 85 +++++++++------------- gsplat/strategy/scaffold.py | 105 +++++++++++----------------- 3 files changed, 127 insertions(+), 116 deletions(-) create mode 100644 examples/benchmarks/scaffold.sh diff --git a/examples/benchmarks/scaffold.sh b/examples/benchmarks/scaffold.sh new file mode 100644 index 000000000..fc6811ba2 --- /dev/null +++ b/examples/benchmarks/scaffold.sh @@ -0,0 +1,53 @@ +SCENE_DIR="data/360_v2" +RESULT_DIR="results/benchmark" +SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +RENDER_TRAJ_PATH="ellipse" + +for SCENE in $SCENE_LIST; +do + if [ "$SCENE" = "bonsai" ] || [ "$SCENE" = "counter" ] || [ "$SCENE" = "kitchen" ] || [ "$SCENE" = "room" ]; then + DATA_FACTOR=2 + else + DATA_FACTOR=4 + fi + + echo "Running $SCENE" + + # train without eval + CUDA_VISIBLE_DEVICES=0 python simple_trainer_scaffold.py --eval_steps -1 --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ + + # run eval and render + for CKPT in $RESULT_DIR/$SCENE/ckpts/*; + do + CUDA_VISIBLE_DEVICES=0 python simple_trainer_scaffold.py --disable_viewer --data_factor $DATA_FACTOR \ + --render_traj_path $RENDER_TRAJ_PATH \ + --data_dir data/360_v2/$SCENE/ \ + --result_dir $RESULT_DIR/$SCENE/ \ + --ckpt $CKPT + done +done + + +for SCENE in $SCENE_LIST; +do + echo "=== Eval Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/val*.json; + do + echo $STATS + cat $STATS; + echo + done + + echo "=== Train Stats ===" + + for STATS in $RESULT_DIR/$SCENE/stats/train*_rank0.json; + do + echo $STATS + cat $STATS; + echo + done +done \ No newline at end of file diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 584c75c3d..c3fd8b6bf 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass, field from collections import defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple import imageio import nerfview @@ -39,7 +39,6 @@ total_variation_loss, ) -from gsplat.compression import PngCompression from gsplat.distributed import cli from gsplat.rendering import rasterization, view_to_visible_anchors from gsplat.strategy import ScaffoldStrategy @@ -50,9 +49,12 @@ class Config: # Disable viewer disable_viewer: bool = False # Path to the .pt files. If provide, it will skip training and run evaluation only. 
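For orientation, the checkpoint written by this patch bundles the anchor parameters with the three decoder state dicts. A minimal inspection sketch (the path is a hypothetical placeholder; the key names mirror the `data` dict assembled in `train()` below):

    import torch

    # Sketch: peek into a checkpoint saved by this trainer. The path is a
    # placeholder; the keys follow the save call added in this patch.
    ckpt = torch.load("results/ckpts/ckpt_6999_rank0.pt", map_location="cpu")
    print(ckpt["step"], ckpt["feat_dim"], ckpt["n_feat_offsets"])
    for name, tensor in ckpt["gauss_params"].items():
        print(name, tuple(tensor.shape))
    # The decoder MLPs are stored as separate state dicts:
    for key in ("opacities_mlp", "colors_mlp", "scale_rot_mlp"):
        print(key, "->", len(ckpt[key]), "tensors")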
- ckpt: Optional[List[str]] = None - # Name of compression strategy to use - compression: Optional[Literal["png"]] = None + # ckpt: Optional[List[str]] = None + ckpt: Optional[List[str]] = field( + default_factory=lambda: [ + "/home/paja/projects/gsplat_fork/results/ckpts/ckpt_1999_rank0.pt" + ] + ) # Render trajectory path render_traj_path: str = "ellipse" @@ -90,7 +92,7 @@ class Config: # Steps to evaluate the model eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) + save_steps: List[int] = field(default_factory=lambda: [2_000, 30_000]) # voxel size for Scaffold-GS voxel_size = 0.001 @@ -377,13 +379,6 @@ def __init__( feat_dim=cfg.feat_dim, n_feat_offsets=cfg.n_feat_offsets, ) - # Compression Strategy - self.compression_method = None - if cfg.compression is not None: - if cfg.compression == "png": - self.compression_method = PngCompression() - else: - raise ValueError(f"Unknown compression strategy: {cfg.compression}") self.pose_optimizers = [] if cfg.pose_opt: @@ -896,7 +891,16 @@ def train(self): json.dump(stats, f) data = { "step": step, - "splats": self.splats["gauss_params"].state_dict(), + "feat_dim": self.cfg.feat_dim, + "n_feat_offsets": self.cfg.n_feat_offsets, + "gauss_params": self.splats["gauss_params"].state_dict(), + "opacities_mlp": self.splats["decoders"][ + "opacities_mlp" + ].state_dict(), + "colors_mlp": self.splats["decoders"]["colors_mlp"].state_dict(), + "scale_rot_mlp": self.splats["decoders"][ + "scale_rot_mlp" + ].state_dict(), } if cfg.pose_opt: if world_size > 1: @@ -962,10 +966,6 @@ def train(self): self.eval(step) # self.render_traj(step) - # run compression - if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]: - self.run_compression(step=step) - if not cfg.disable_viewer: self.viewer.lock.release() num_train_steps_per_sec = 1.0 / (time.time() - tic) @@ -978,9 +978,16 @@ def train(self): self.viewer.update(step, num_train_rays_per_step) @torch.no_grad() - def eval(self, step: int, stage: str = "val"): + def eval(self, step: int, n_feat_offsets: int, feat_dim: int, stage: str = "val"): """Entry for evaluation.""" print("Running evaluation...") + assert ( + n_feat_offsets == self.cfg.n_feat_offsets + ), f"Feature offset count changed, should be {n_feat_offsets}" + assert ( + feat_dim == self.cfg.feat_dim + ), f"Feature dim changed, should be {feat_dim}" + cfg = self.cfg device = self.device world_rank = self.world_rank @@ -1132,23 +1139,6 @@ def render_traj(self, step: int): writer.close() print(f"Video saved to {video_dir}/traj_{step}.mp4") - @torch.no_grad() - def run_compression(self, step: int): - """Entry for running compression.""" - print("Running compression...") - world_rank = self.world_rank - - compress_dir = f"{cfg.result_dir}/compression/rank{world_rank}" - os.makedirs(compress_dir, exist_ok=True) - - self.compression_method.compress(compress_dir, self.splats["gauss_params"]) - - # evaluate compression - splats_c = self.compression_method.decompress(compress_dir) - for k in splats_c.keys(): - self.splats["gauss_params"][k].data = splats_c[k].to(self.device) - self.eval(step=step, stage="compress") - @torch.no_grad() def _viewer_render_fn( self, camera_state: nerfview.CameraState, img_wh: Tuple[int, int] @@ -1185,15 +1175,18 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): torch.load(file, map_location=runner.device, weights_only=True) for file in cfg.ckpt ] + for k in 
runner.splats["gauss_params"].keys(): runner.splats["gauss_params"][k].data = torch.cat( - [ckpt["splats"][k] for ckpt in ckpts] + [ckpt["gauss_params"][k] for ckpt in ckpts] ) + for k in runner.splats["decoders"].keys(): + runner.splats["decoders"][k].load_state_dict(ckpts[0][k]) step = ckpts[0]["step"] - runner.eval(step=step) + n_feat_offsets = ckpts[0]["n_feat_offsets"] + feat_dim = ckpts[0]["feat_dim"] + runner.eval(step=step, n_feat_offsets=n_feat_offsets, feat_dim=feat_dim) runner.render_traj(step=step) - if cfg.compression is not None: - runner.run_compression(step=step) else: runner.train() @@ -1218,16 +1211,4 @@ def main(local_rank: int, world_rank, world_size: int, cfg: Config): cfg = tyro.cli(Config) cfg.adjust_steps(cfg.steps_scaler) - # try import extra dependencies - if cfg.compression == "png": - try: - import plas - import torchpq - except: - raise ImportError( - "To use PNG compression, you need to install " - "torchpq (instruction at https://github.com/DeMoriarty/TorchPQ?tab=readme-ov-file#install) " - "and plas (via 'pip install git+https://github.com/fraunhoferhhi/PLAS.git') " - ) - cli(main, cfg, verbose=True) diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 1585299fa..bb2e6b11b 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -34,14 +34,6 @@ class ScaffoldStrategy(Strategy): prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005. grow_grad2d (float): GSs with image plane gradient above this value will be split/duplicated. Default is 0.0002. - grow_scale3d (float): GSs with 3d scale (normalized by scene_scale) below this - value will be duplicated. Above will be split. Default is 0.01. - grow_scale2d (float): GSs with 2d scale (normalized by image resolution) above - this value will be split. Default is 0.05. - prune_scale3d (float): GSs with 3d scale (normalized by scene_scale) above this - value will be pruned. Default is 0.1. - prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above - this value will be pruned. Default is 0.15. refine_start_iter (int): Start refining GSs after this iteration. Default is 500. refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. refine_every (int): Refine GSs every this steps. Default is 100. @@ -70,10 +62,6 @@ class ScaffoldStrategy(Strategy): prune_opa: float = 0.005 grow_grad2d: float = 1.28e-4 - grow_scale3d: float = 0.01 - grow_scale2d: float = 0.05 - prune_scale3d: float = 0.1 - prune_scale2d: float = 0.15 refine_start_iter: int = 800 absgrad: bool = False max_voxel_levels: int = 3 @@ -87,8 +75,9 @@ class ScaffoldStrategy(Strategy): growing_thresholds: float = 0.4 pause_refine_after_reset: int = 0 - absgrad: bool = False verbose: bool = True + drop_out: bool = False + pruning: bool = False def initialize_state( self, scene_scale: float = 1.0, feat_dim=128, n_feat_offsets=10 @@ -198,56 +187,42 @@ def step_post_backward( new_anchors = self._anchor_growing(params, optimizers, state) if self.verbose: print( - f"Step {step}: {new_anchors} anchors grown." - f"Now having {len(params['anchors'])} anchors." + f"Step {step}: {new_anchors} anchors grown. Now having {len(params['anchors'])} anchors." 
) - # if step % 1000 == 0: - # low_opacity_mask = ( - # state["anchor_opacity"] < self.prune_opa * state["anchor_count"] - # ).squeeze() - # anchor_mask = ( - # state["anchor_count"] > self.pruning_thresholds * self.refine_every - # ) # [N, 1] - # is_prune = torch.logical_and(low_opacity_mask, anchor_mask) - # - # indices = is_prune.nonzero(as_tuple=False).squeeze() - # # Efficiently set the specified indices to zero - # state["anchor_count"].index_fill_(0, indices, 0) - # state["anchor_opacity"].index_fill_(0, indices, 0) - # - # n_prune = is_prune.sum().item() - # if n_prune > 0: - # remove_anchors( - # params=params, - # optimizers=optimizers, - # n_feat_offsets=self.n_feat_offsets, - # state=state, - # mask=is_prune, - # ) - # - # if self.verbose: - # print( - # f"Step {step}: {n_prune} anchors pruned. " - # f"Now having {len(params['anchors'])} anchors." - # ) - # else: - n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) - if self.verbose: - print(f"Relocated anchors {n_relocated_gs}") - - # def op_sigmoid(x, k=100, x0=0.995): - # return 1 / (1 + torch.exp(-k * (x - x0))) - # - # opacities = torch.sigmoid(params["opacities"]) - # noise = ( - # torch.randn_like(params["offsets"]) - # * (op_sigmoid(1 - opacities)).unsqueeze(-1) - # - # * 5e5 * 0.00001 - # ) - # - # params["offsets"] = params["offsets"] + noise + if self.pruning: + low_opacity_mask = ( + state["anchor_opacity"] < self.prune_opa * state["anchor_count"] + ).squeeze() + anchor_mask = ( + state["anchor_count"] > self.pruning_thresholds * self.refine_every + ) # [N, 1] + is_prune = torch.logical_and(low_opacity_mask, anchor_mask) + + indices = is_prune.nonzero(as_tuple=False).squeeze() + # Efficiently set the specified indices to zero + state["anchor_count"].index_fill_(0, indices, 0) + state["anchor_opacity"].index_fill_(0, indices, 0) + + n_prune = is_prune.sum().item() + if n_prune > 0: + remove_anchors( + params=params, + optimizers=optimizers, + n_feat_offsets=state["n_feat_offsets"], + state=state, + mask=is_prune, + ) + + if self.verbose: + print( + f"Step {step}: {n_prune} anchors pruned. Now having {len(params['anchors'])} anchors." + ) + else: + n_relocated_gs = self._relocate_gs(params, optimizers, binoms, state) + if self.verbose: + print(f"Relocated anchors {n_relocated_gs}") + torch.cuda.empty_cache() def _update_state( @@ -387,9 +362,11 @@ def _anchor_growing( gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask) # Drop-out: Helps prevent too massive anchor growth. 
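For intuition, a standalone version of the mask that the next hunk turns into an opt-in flag: at voxel level m a candidate survives with probability 1 - 0.5**m, so coarser levels are thinned hardest (shapes and values below are illustrative):

    import torch

    # Sketch: stochastic thinning of anchor-growth candidates at voxel level m.
    # Keep-probability is 1 - 0.5**m (~50% at m=1, ~87.5% at m=3).
    m = 1
    gradient_mask = torch.rand(1000) > 0.7  # stand-in for the real gradient test
    rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m)
    gradient_mask = torch.logical_and(gradient_mask, rand_mask)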
-        # rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m)
-        # rand_mask = rand_mask.cuda()
-        # gradient_mask = torch.logical_and(gradient_mask, rand_mask)
+        if self.drop_out:
+            rand_mask = torch.rand_like(gradient_mask.float()) > (0.5**m)
+            rand_mask = rand_mask.cuda()
+            gradient_mask = torch.logical_and(gradient_mask, rand_mask)
+
         gradient_mask = torch.cat(
             [
                 gradient_mask,

From 1d3f7861cee5460bbbc19a4914ef9a66538ca3b6 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Fri, 27 Sep 2024 09:33:55 +0200
Subject: [PATCH 19/29] fix docs

---
 gsplat/__init__.py | 2 +-
 gsplat/strategy/scaffold.py | 35 +++++++++++++++++------------------
 2 files changed, 18 insertions(+), 19 deletions(-)

diff --git a/gsplat/__init__.py b/gsplat/__init__.py
index df47d1555..03ba23b40 100644
--- a/gsplat/__init__.py
+++ b/gsplat/__init__.py
@@ -23,7 +23,7 @@
     rasterization_inria_wrapper,
     rasterization_2dgs_inria_wrapper,
 )
-from .strategy import DefaultStrategy, MCMCStrategy, Strategy
+from .strategy import DefaultStrategy, MCMCStrategy, ScaffoldStrategy, Strategy
 from .version import __version__
 
 all = [
diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py
index bb2e6b11b..d331172c9 100644
--- a/gsplat/strategy/scaffold.py
+++ b/gsplat/strategy/scaffold.py
@@ -16,10 +16,8 @@ class ScaffoldStrategy(Strategy):
 
     The strategy will:
 
-    - Periodically duplicate GSs with high image plane gradients and small scales.
-    - Periodically split GSs with high image plane gradients and large scales.
-    - Periodically prune GSs with low opacity.
-    - Periodically reset GSs to a lower opacity.
+    - Periodically grow anchors with high image plane gradients.
+    - Periodically teleport anchors with low opacity to a place that has high opacity.
 
     If `absgrad=True`, it will use the absolute gradients instead of average gradients
     for GS duplicating & splitting, following the AbsGS paper:
@@ -31,24 +29,26 @@ class ScaffoldStrategy(Strategy):
     with `absgrad=True` as well so that the absolute gradients are computed.
 
     Args:
-        prune_opa (float): GSs with opacity below this value will be pruned. Default is 0.005.
-        grow_grad2d (float): GSs with image plane gradient above this value will be
-            split/duplicated. Default is 0.0002.
-        refine_start_iter (int): Start refining GSs after this iteration. Default is 500.
-        refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000.
-        refine_every (int): Refine GSs every this steps. Default is 100.
-        pause_refine_after_reset (int): Pause refining GSs until this number of steps after
-            reset, Default is 0 (no pause at all) and one might want to set this number to the
-            number of images in training set.
-        absgrad (bool): Use absolute gradients for GS splitting. Default is False.
-        verbose (bool): Whether to print verbose information. Default is False.
+        prune_opa (float): Threshold for pruning GSs with opacity below this value. Default is 0.005.
+        grow_grad2d (float): Threshold for splitting/duplicating GSs based on image plane gradient. Default is 0.0002.
+        refine_start_iter (int): Iteration to start refining GSs. Default is 500.
+        refine_stop_iter (int): Iteration to stop refining GSs. Default is 15,000.
+        refine_every (int): Frequency (in steps) at which GSs are refined. Default is 100.
+        absgrad (bool): Whether to use absolute gradients for GS splitting. Default is False.
+        verbose (bool): If True, prints detailed information during refinement. Default is False.
+        max_voxel_levels (int): Maximum levels for voxel splitting during GS growth. Default is 3.
+ voxel_size (float): Base size of the voxel used in GS growth. Default is 0.001. + pruning_thresholds (float): Threshold for pruning based on refinement steps. Default is 0.8. + growing_thresholds (float): Threshold for GS growth based on refinement steps. Default is 0.4. + drop_out (bool): If True, applies dropout during GS growth to prevent overgrowth. Default is False. + pruning (bool): If True, enables pruning of GSs during refinement. Default is False. Examples: - >>> from gsplat import DefaultStrategy, rasterization + >>> from gsplat import ScaffoldStrategy, rasterization >>> params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ... >>> optimizers: Dict[str, torch.optim.Optimizer] = ... - >>> strategy = DefaultStrategy() + >>> strategy = ScaffoldStrategy() >>> strategy.check_sanity(params, optimizers) >>> strategy_state = strategy.initialize_state() >>> for step in range(1000): @@ -74,7 +74,6 @@ class ScaffoldStrategy(Strategy): pruning_thresholds: float = 0.8 growing_thresholds: float = 0.4 - pause_refine_after_reset: int = 0 verbose: bool = True drop_out: bool = False pruning: bool = False From 60426afabaf7b34fb58800b2960bc6e1ab91de3f Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 09:37:45 +0200 Subject: [PATCH 20/29] cleanup appearance embedding --- examples/simple_trainer_scaffold.py | 56 ++--------------------------- 1 file changed, 2 insertions(+), 54 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index c3fd8b6bf..1f584151f 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -92,7 +92,7 @@ class Config: # Steps to evaluate the model eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # Steps to save the model - save_steps: List[int] = field(default_factory=lambda: [2_000, 30_000]) + save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000]) # voxel size for Scaffold-GS voxel_size = 0.001 @@ -134,15 +134,6 @@ class Config: # Add noise to camera extrinsics. This is only to test the camera pose optimization. pose_noise: float = 0.0 - # Enable appearance optimization. (experimental) - app_opt: bool = False - # Appearance embedding dimension - app_embed_dim: int = 16 - # Learning rate for appearance optimization - app_opt_lr: float = 1e-3 - # Regularization for appearance optimization as weight decay - app_opt_reg: float = 1e-6 - # Enable bilateral grid. (experimental) use_bilateral_grid: bool = False # Shape of the bilateral grid (X, Y, W) @@ -349,7 +340,6 @@ def __init__( print("Scene scale:", self.scene_scale) # Model - feature_dim = 32 if cfg.app_opt else None self.splats, self.optimizers = create_splats_with_optimizers( self.parser, init_extent=cfg.init_extent, @@ -358,7 +348,6 @@ def __init__( scene_scale=self.scene_scale, sparse_grad=cfg.sparse_grad, batch_size=cfg.batch_size, - feature_dim=feature_dim, device=self.device, world_rank=world_rank, world_size=world_size, @@ -400,29 +389,6 @@ def __init__( if world_size > 1: self.pose_perturb = DDP(self.pose_perturb) - self.app_optimizers = [] - if cfg.app_opt: - assert feature_dim is not None - self.app_module = AppearanceOptModule( - len(self.trainset), feature_dim, cfg.app_embed_dim, None - ).to(self.device) - # initialize the last layer to be zero so that the initial output is zero. 
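The zero-init mentioned in the comment above is the standard residual-head trick: with the final linear layer zeroed, the appearance head contributes nothing at step 0 and only learns a correction on top of the base colors. In isolation (layer sizes are illustrative):

    import torch

    # Sketch: zero the last layer so the head's initial output is exactly 0.
    color_head = torch.nn.Sequential(
        torch.nn.Linear(32, 64), torch.nn.ReLU(True), torch.nn.Linear(64, 3)
    )
    torch.nn.init.zeros_(color_head[-1].weight)
    torch.nn.init.zeros_(color_head[-1].bias)
    assert torch.allclose(color_head(torch.randn(5, 32)), torch.zeros(5, 3))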
- torch.nn.init.zeros_(self.app_module.color_head[-1].weight) - torch.nn.init.zeros_(self.app_module.color_head[-1].bias) - self.app_optimizers = [ - torch.optim.Adam( - self.app_module.embeds.parameters(), - lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size) * 10.0, - weight_decay=cfg.app_opt_reg, - ), - torch.optim.Adam( - self.app_module.color_head.parameters(), - lr=cfg.app_opt_lr * math.sqrt(cfg.batch_size), - ), - ] - if world_size > 1: - self.app_module = DDP(self.app_module) - self.bil_grid_optimizers = [] if cfg.use_bilateral_grid: self.bil_grids = BilateralGrid( @@ -622,17 +588,7 @@ def rasterize_splats( rasterize_mode="antialiased" if self.cfg.antialiased else "classic", ) - image_ids = kwargs.pop("image_ids", None) - if self.cfg.app_opt: - colors = self.app_module( - features=self.splats["gauss_params"]["features"], - embed_ids=image_ids, - dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3], - ) - colors = colors + info["colors"] - colors = torch.sigmoid(colors) - else: - colors = info["colors"] # [N, K, 3] + colors = info["colors"] # [N, K, 3] rasterize_mode = "antialiased" if self.cfg.antialiased else "classic" render_colors, render_alphas, raster_info = rasterization( @@ -907,11 +863,6 @@ def train(self): data["pose_adjust"] = self.pose_adjust.module.state_dict() else: data["pose_adjust"] = self.pose_adjust.state_dict() - if cfg.app_opt: - if world_size > 1: - data["app_module"] = self.app_module.module.state_dict() - else: - data["app_module"] = self.app_module.state_dict() torch.save( data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt" ) @@ -952,9 +903,6 @@ def train(self): for optimizer in self.pose_optimizers: optimizer.step() optimizer.zero_grad(set_to_none=True) - for optimizer in self.app_optimizers: - optimizer.step() - optimizer.zero_grad(set_to_none=True) for optimizer in self.bil_grid_optimizers: optimizer.step() optimizer.zero_grad(set_to_none=True) From 7bbf30d3b87213b642565b349aae99249d35eb5d Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 09:40:06 +0200 Subject: [PATCH 21/29] update requirements --- examples/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/requirements.txt b/examples/requirements.txt index c91c96a92..0c69fa344 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -18,5 +18,6 @@ Pillow tensorboard tensorly pyyaml +torch_scatter matplotlib git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157 From 40dc7accff25a295bf75e11deecca73b003d8cd6 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 09:46:30 +0200 Subject: [PATCH 22/29] fix --- examples/simple_trainer_scaffold.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 1f584151f..b8fe4925b 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -49,12 +49,12 @@ class Config: # Disable viewer disable_viewer: bool = False # Path to the .pt files. If provide, it will skip training and run evaluation only. 
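For reference, the appearance optimizer removed in the cleanup above used two Adam groups: the per-image embeddings at 10x the base learning rate with weight decay as the regularizer, and the color head at the base rate. A self-contained sketch of that pattern (module sizes are placeholders; the learning-rate values match the old defaults):

    import math

    import torch

    # Sketch of the removed two-group setup: embeddings learn 10x faster and
    # carry the weight decay; the color head runs at the base rate.
    embeds = torch.nn.Embedding(100, 16)
    color_head = torch.nn.Linear(16 + 32, 3)
    app_opt_lr, app_opt_reg, batch_size = 1e-3, 1e-6, 1
    app_optimizers = [
        torch.optim.Adam(
            embeds.parameters(),
            lr=app_opt_lr * math.sqrt(batch_size) * 10.0,
            weight_decay=app_opt_reg,
        ),
        torch.optim.Adam(
            color_head.parameters(), lr=app_opt_lr * math.sqrt(batch_size)
        ),
    ]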
- # ckpt: Optional[List[str]] = None - ckpt: Optional[List[str]] = field( - default_factory=lambda: [ - "/home/paja/projects/gsplat_fork/results/ckpts/ckpt_1999_rank0.pt" - ] - ) + ckpt: Optional[List[str]] = None + # ckpt: Optional[List[str]] = field( + # default_factory=lambda: [ + # "/home/paja/projects/gsplat_fork/results/ckpts/ckpt_1999_rank0.pt" + # ] + # ) # Render trajectory path render_traj_path: str = "ellipse" @@ -729,7 +729,6 @@ def train(self): sh_degree=None, near_plane=cfg.near_plane, far_plane=cfg.far_plane, - image_ids=image_ids, render_mode="RGB+ED" if cfg.depth_loss else "RGB", ) if renders.shape[-1] == 4: From 6a55098063352efa928892273c596bd70697f4ab Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 10:03:20 +0200 Subject: [PATCH 23/29] fixes --- examples/benchmarks/scaffold.sh | 3 ++- examples/simple_trainer_scaffold.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/benchmarks/scaffold.sh b/examples/benchmarks/scaffold.sh index fc6811ba2..7452ad1bd 100644 --- a/examples/benchmarks/scaffold.sh +++ b/examples/benchmarks/scaffold.sh @@ -1,6 +1,7 @@ SCENE_DIR="data/360_v2" RESULT_DIR="results/benchmark" -SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +#SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers +SCENE_LIST="counter" # treehill flowers RENDER_TRAJ_PATH="ellipse" for SCENE in $SCENE_LIST; diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index b8fe4925b..26b9c43b2 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -31,7 +31,7 @@ from fused_ssim import fused_ssim from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity from typing_extensions import Literal -from utils import AppearanceOptModule, CameraOptModule, knn, set_random_seed +from utils import CameraOptModule, knn, set_random_seed from lib_bilagrid import ( BilateralGrid, slice, From 6bb591e6c06d9f585988503c56ddae80a85b75c8 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 10:07:48 +0200 Subject: [PATCH 24/29] build fix hopefully --- examples/requirements.txt | 1 - setup.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/requirements.txt b/examples/requirements.txt index 0c69fa344..c91c96a92 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -18,6 +18,5 @@ Pillow tensorboard tensorly pyyaml -torch_scatter matplotlib git+https://github.com/rahul-goel/fused-ssim@1272e21a282342e89537159e4bad508b19b34157 diff --git a/setup.py b/setup.py index ff0dd8e00..608fc4866 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,7 @@ def get_extensions(): "jaxtyping", "rich>=12", "torch", - "typing_extensions; python_version<'3.8'", + "torch_scatter" "typing_extensions; python_version<'3.8'", ], extras_require={ # dev dependencies. Install them by `pip install gsplat[dev]` From c4e4f6c1e934570c68e610cfaadfef8ede631ead Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 10:13:20 +0200 Subject: [PATCH 25/29] fix --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 608fc4866..fd8966d6a 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,8 @@ def get_extensions(): "jaxtyping", "rich>=12", "torch", - "torch_scatter" "typing_extensions; python_version<'3.8'", + "torch_scatter", + "typing_extensions; python_version<'3.8'", ], extras_require={ # dev dependencies. 
Install them by `pip install gsplat[dev]` From 9fafbb31f6e1501a35caf5886265ff26d91909ca Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 10:46:55 +0200 Subject: [PATCH 26/29] fixes --- .gitignore | 3 +++ examples/simple_trainer_scaffold.py | 8 ++++++-- gsplat/strategy/ops.py | 5 ++--- setup.py | 1 - 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 65a25f051..9ccfebea6 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,9 @@ compile_commands.json # Visual Studio Code configs. .vscode/ +# Pycharm +.idea + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 26b9c43b2..328aac52b 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -910,8 +910,12 @@ def train(self): # eval the full set if step in [i - 1 for i in cfg.eval_steps]: - self.eval(step) - # self.render_traj(step) + self.eval( + step, + n_feat_offsets=self.cfg.n_feat_offsets, + feat_dim=self.cfg.feat_dim, + ) + self.render_traj(step) if not cfg.disable_viewer: self.viewer.lock.release() diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index c00b34a67..7c4731caf 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -4,7 +4,6 @@ import torch import torch.nn.functional as F from torch import Tensor -from torch_scatter import scatter_max from gsplat import quat_scale_to_covar_preci from gsplat.relocation import compute_relocation @@ -417,8 +416,8 @@ def grow_anchors( selected_features = repeated_features[gradient_mask] # [N_selected, feat_dim] # Use inverse_indices to aggregate features - scattered_features, _ = scatter_max( - selected_features, inv_idx.unsqueeze(1).expand(-1, feat_dim), dim=0 + scattered_features = torch.segment_reduce( + data=selected_features, reduce="amax", lengths=torch.bincount(inv_idx) ) feat = scattered_features[remove_duplicates_mask] # [N_new, feat_dim] diff --git a/setup.py b/setup.py index fd8966d6a..ff0dd8e00 100644 --- a/setup.py +++ b/setup.py @@ -118,7 +118,6 @@ def get_extensions(): "jaxtyping", "rich>=12", "torch", - "torch_scatter", "typing_extensions; python_version<'3.8'", ], extras_require={ From 139b52111c835db0785bcfd812417a96cb749652 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 27 Sep 2024 11:40:56 +0200 Subject: [PATCH 27/29] clean up for review --- examples/benchmarks/scaffold.sh | 5 ++--- examples/simple_trainer_scaffold.py | 20 +------------------- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/examples/benchmarks/scaffold.sh b/examples/benchmarks/scaffold.sh index 7452ad1bd..6358077f5 100644 --- a/examples/benchmarks/scaffold.sh +++ b/examples/benchmarks/scaffold.sh @@ -1,7 +1,6 @@ SCENE_DIR="data/360_v2" RESULT_DIR="results/benchmark" -#SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers -SCENE_LIST="counter" # treehill flowers +SCENE_LIST="garden bicycle stump bonsai counter kitchen room" # treehill flowers RENDER_TRAJ_PATH="ellipse" for SCENE in $SCENE_LIST; @@ -51,4 +50,4 @@ do cat $STATS; echo done -done \ No newline at end of file +done diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py index 328aac52b..75efd2d94 100644 --- a/examples/simple_trainer_scaffold.py +++ b/examples/simple_trainer_scaffold.py @@ -50,20 +50,13 @@ class Config: disable_viewer: bool = False # Path to the .pt files. If provide, it will skip training and run evaluation only. 
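The scatter_max-to-segment_reduce swap in the ops.py hunk above drops the torch_scatter dependency. Another dependency-free formulation of the same per-voxel feature maximum is torch's in-core scatter_reduce_, which, unlike a lengths-based segment reduction, does not require the rows of one voxel to be contiguous. A minimal sketch (shapes and values are illustrative, not the patch's code):

    import torch

    # Sketch: max-reduce candidate features into voxel buckets without
    # torch_scatter. inv_idx maps each candidate row to its voxel id.
    feat_dim = 4
    selected_features = torch.randn(6, feat_dim)
    inv_idx = torch.tensor([0, 0, 1, 2, 2, 2])
    out = torch.full((int(inv_idx.max()) + 1, feat_dim), float("-inf"))
    out.scatter_reduce_(
        0,
        inv_idx.unsqueeze(1).expand(-1, feat_dim),
        selected_features,
        reduce="amax",
        include_self=False,
    )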
ckpt: Optional[List[str]] = None
-    # ckpt: Optional[List[str]] = field(
-    #     default_factory=lambda: [
-    #         "/home/paja/projects/gsplat_fork/results/ckpts/ckpt_1999_rank0.pt"
-    #     ]
-    # )
     # Render trajectory path
     render_traj_path: str = "ellipse"
 
     # Path to the Mip-NeRF 360 dataset
     data_dir: str = "examples/data/360_v2/room"
-    # data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/tanksandtemples/truck/"
-    # data_dir: str = "/home/paja/data/bike_aliked"
     # Downsample factor for the dataset
-    data_factor: int = 4
+    data_factor: int = 2
     # Directory to save results
     result_dir: str = "results"
     # Every N images there is a test image
@@ -766,17 +759,6 @@ def train(self):
         loss += ssimloss * cfg.ssim_lambda
         loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg
 
-        # Apply sigmoid to normalize values to [0, 1]
-        # sigmoid_opacities = torch.sigmoid(info["opacities"])
-        #
-        # # Custom loss to penalize values not close to 0 or 1
-        # def binarization_loss(x):
-        #     return (x * (1 - x)).mean()
-        #
-        # # Calculate the binarization loss
-        # opa_loss = binarization_loss(sigmoid_opacities)
-        # loss += 0.01 * opa_loss
-
         if cfg.depth_loss:
             # query depths from depth map
             points = torch.stack(

From e8d1207bcef4bbca9e8ea742fea3802bccf2a410 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Fri, 27 Sep 2024 21:41:42 +0200
Subject: [PATCH 28/29] bug fixed, psnr 27.99 on garden scene

---
 examples/simple_trainer_scaffold.py | 4 ++--
 gsplat/strategy/ops.py | 7 ++++++-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py
index 75efd2d94..45af847c3 100644
--- a/examples/simple_trainer_scaffold.py
+++ b/examples/simple_trainer_scaffold.py
@@ -54,9 +54,9 @@ class Config:
     render_traj_path: str = "ellipse"
 
     # Path to the Mip-NeRF 360 dataset
-    data_dir: str = "examples/data/360_v2/room"
+    data_dir: str = "examples/data/360_v2/garden"
     # Downsample factor for the dataset
-    data_factor: int = 2
+    data_factor: int = 4
     # Directory to save results
     result_dir: str = "results"
diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py
index 7c4731caf..31affe6c2 100644
--- a/gsplat/strategy/ops.py
+++ b/gsplat/strategy/ops.py
@@ -284,15 +284,20 @@ def param_fn(name: str, p: Tensor) -> Tensor:
 
     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         v[sampled_idxs] = 0
+        v[dead_indices] = 0
         return v
 
     # update the parameters and the state in the optimizers
     _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers)
     # update the extra running state
     for k, v in state.items():
-        if isinstance(v, torch.Tensor):
+        if isinstance(v, torch.Tensor) and k != "binoms":
             if k == "anchor_count" or k == "anchor_opacity":
                 v[sampled_idxs] = 0
+                v[dead_indices] = 0
+            else:
+                v.view(-1, state["n_feat_offsets"])[sampled_idxs] = 0
+                v.view(-1, state["n_feat_offsets"])[dead_indices] = 0
 
 
 @torch.no_grad()

From de35ba3d62b9cd8d6a298920112746e3977ba3a6 Mon Sep 17 00:00:00 2001
From: Janusch Patas
Date: Sat, 28 Sep 2024 22:15:26 +0200
Subject: [PATCH 29/29] fix merge break

---
 gsplat/rendering.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gsplat/rendering.py b/gsplat/rendering.py
index d263897a5..b94d736ef 100644
--- a/gsplat/rendering.py
+++ b/gsplat/rendering.py
@@ -597,7 +597,7 @@ def view_to_visible_anchors(
     sh_degree: Optional[int] = None,
     packed: bool = True,
     rasterize_mode: Literal["classic", "antialiased"] = "classic",
-    ortho: bool = False,
+    camera_model: Literal["pinhole", "ortho", "fisheye"] = "pinhole",
"pinhole", covars: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor, Dict]: """Rasterize a set of 3D Gaussians (N) to a batch of image planes (C). @@ -773,7 +773,7 @@ def view_to_visible_anchors( radius_clip=radius_clip, sparse_grad=False, calc_compensations=(rasterize_mode == "antialiased"), - ortho=ortho, + camera_model=camera_model, ) if packed: