diff --git a/examples/simple_trainer_scaffold.py b/examples/simple_trainer_scaffold.py
index 9debd839e..4e612bf9a 100644
--- a/examples/simple_trainer_scaffold.py
+++ b/examples/simple_trainer_scaffold.py
@@ -54,7 +54,8 @@ class Config:
     render_traj_path: str = "interp"

     # Path to the Mip-NeRF 360 dataset
-    data_dir: str = "/home/paja/data/bike"
+    # data_dir: str = "/home/paja/.cache/nerfbaselines/datasets/mipnerf360/garden"
+    data_dir: str = "/home/paja/data/bike_aliked"
     # Downsample factor for the dataset
     data_factor: int = 1
     # Directory to save results
@@ -88,8 +89,6 @@ class Config:
     # Initial extent of GSs as a multiple of the camera extent. Ignored if using sfm
     init_extent: float = 3.0
-    # Turn on another SH degree every this steps
-    sh_degree_interval: int = 1000
     # Initial opacity of GS
     init_opa: float = 0.1
     # Initial scale of GS
     init_scale: float = 1.0
@@ -113,8 +112,6 @@ class Config:
     # Use random background for training to discourage transparency
     random_bkgd: bool = False
-    # Opacity regularization
-    opacity_reg: float = 0.0
     # Scale regularization
     scale_reg: float = 0.01
@@ -157,7 +154,6 @@ def adjust_steps(self, factor: float):
         self.eval_steps = [int(i * factor) for i in self.eval_steps]
         self.save_steps = [int(i * factor) for i in self.save_steps]
         self.max_steps = int(self.max_steps * factor)
-        self.sh_degree_interval = int(self.sh_degree_interval * factor)

         strategy = self.strategy
         strategy.refine_start_iter = int(strategy.refine_start_iter * factor)
@@ -196,7 +192,6 @@ def create_splats_with_optimizers(
     # Distribute the GSs to different ranks (also works for single rank)
     points = points[world_rank::world_size]
     scales = scales[world_rank::world_size]
-
     N = points.shape[0]
     quats = torch.rand((N, 4))  # [N, 4]
@@ -209,8 +204,8 @@ def create_splats_with_optimizers(
         ("scales", torch.nn.Parameter(scales), 0.007),
         ("quats", torch.nn.Parameter(quats), 0.002),
         ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002),
-        ("features", torch.nn.Parameter(features), 0.0075 * scene_scale),
-        ("offsets", torch.nn.Parameter(offsets), 0.01),
+        ("features", torch.nn.Parameter(features), 0.0075),
+        ("offsets", torch.nn.Parameter(offsets), 0.01 * scene_scale),
         ("colors_mlp", strategy.colors_mlp.parameters(), 0.008),
         ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004),
     ]
@@ -242,6 +237,7 @@ def __init__(
         set_random_seed(42 + local_rank)

         self.cfg = cfg
+        self.cfg.strategy.voxel_size = self.cfg.voxel_size
         self.world_rank = world_rank
         self.local_rank = local_rank
         self.world_size = world_size
@@ -395,24 +391,20 @@ def __init__(
                 mode="training",
             )

-    def get_visible_anchor_mask(
-        self,
-        camtoworlds: Tensor,
-        Ks: Tensor,
-        width: int,
-        height: int,
-        packed: bool,
-        rasterize_mode: str,
+    def get_neural_gaussians(self,
+        camtoworlds: Tensor,
+        Ks: Tensor,
+        width: int,
+        height: int,
+        packed: bool,
+        rasterize_mode: str,
     ):
-        anchors = self.splats["anchors"]  # [N, 3]
-        # rasterization does normalization internally
-        quats = self.splats["quats"]  # [N, 4]
-        scales = torch.exp(self.splats["scales"])[:, :3]  # [N, 3]
+        # Compare paper: helps mainly to speed up the rasterization. Has no quality impact.
         visible_anchor_mask = view_to_visible_anchors(
-            means=anchors,
-            quats=quats,
-            scales=scales,
+            means=self.splats["anchors"],
+            quats=self.splats["quats"],
+            scales=torch.exp(self.splats["scales"])[:, :3],
             viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
             Ks=Ks,  # [C, 3, 3]
             width=width,
@@ -420,26 +412,23 @@ def get_visible_anchor_mask(
             packed=packed,
             rasterize_mode=rasterize_mode,
         )
-
-        return visible_anchor_mask
-
-    def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None):
-        # If no visibility mask is provided, we select all anchors including their offsets
-        if visible_anchor_mask is None:
-            visible_anchor_mask = torch.ones(self.splats["anchors"].shape[0], dtype=torch.bool, device=self.device)
-
         selected_features = self.splats["features"][visible_anchor_mask]  # [M, c]
         selected_anchors = self.splats["anchors"][visible_anchor_mask]  # [M, 3]
         selected_offsets = self.splats["offsets"][visible_anchor_mask]  # [M, k, 3]
         selected_scales = torch.exp(self.splats["scales"][visible_anchor_mask])  # [M, 6]

         # See formula (5) in Scaffold-GS
+
+        cam_pos = camtoworlds[:, :3, 3]
         view_dir = selected_anchors - cam_pos  # [M, 3]
-        view_dir_normalized = view_dir / view_dir.norm(dim=1, keepdim=True)  # [M, 3]
+        length = view_dir.norm(dim=1, keepdim=True)
+        view_dir_normalized = view_dir / length  # [M, 3]
+
+        view_length = torch.cat([view_dir_normalized, length], dim=1)

         # See formula (9) and the appendix for the rest
-        feature_view_dir = torch.cat([selected_features, view_dir_normalized], dim=1)  # [M, c+3]
+        feature_view_dir = torch.cat([selected_features, view_length], dim=1)  # [M, c+4]

         k = self.cfg.strategy.n_feat_offsets  # Number of offsets per anchor
@@ -462,22 +451,36 @@ def get_neural_gaussians(self, cam_pos, visible_anchor_mask=None):
         anchors_repeated = selected_anchors.unsqueeze(1).repeat(1, k, 1).view(-1, 3)  # [M*k, 3]

         # Apply positive opacity mask
-        selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1)  # [m]
-        selected_colors = neural_colors[neural_selection_mask]  # [m, 3]
-        selected_scale_rot = neural_scale_rot[neural_selection_mask]  # [m, 7]
-        selected_offsets = selected_offsets[neural_selection_mask]  # [m, 3]
-        scales_repeated = scales_repeated[neural_selection_mask]  # [m, 6]
-        anchors_repeated = anchors_repeated[neural_selection_mask]  # [m, 3]
+        selected_opacity = neural_opacity[neural_selection_mask].squeeze(-1)  # [M]
+        selected_colors = neural_colors[neural_selection_mask]  # [M, 3]
+        selected_scale_rot = neural_scale_rot[neural_selection_mask]  # [M, 7]
+        selected_offsets = selected_offsets[neural_selection_mask]  # [M, 3]
+        scales_repeated = scales_repeated[neural_selection_mask]  # [M, 6]
+        anchors_repeated = anchors_repeated[neural_selection_mask]  # [M, 3]

         # Compute scales and rotations
-        scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3])  # [m, 3]
-        rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7])  # [m, 4]
+        scales = scales_repeated[:, 3:] * torch.sigmoid(selected_scale_rot[:, :3])  # [M, 3]
+        rotation = torch.nn.functional.normalize(selected_scale_rot[:, 3:7])  # [M, 4]

         # Compute offsets and anchors
-        offsets = selected_offsets * scales_repeated[:, :3]  # [m, 3]
-        means = anchors_repeated + offsets  # [m, 3]
-
-        return means, selected_colors, selected_opacity, scales, rotation, neural_opacity, neural_selection_mask
+        offsets = selected_offsets * scales_repeated[:, :3]  # [M, 3]
+        means = anchors_repeated + offsets  # [M, 3]
+
+        v_a = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.cfg.strategy.n_feat_offsets]).view(-1)
+        all_neural_gaussians = torch.zeros_like(v_a, dtype=torch.bool)
+        all_neural_gaussians[v_a] = neural_selection_mask
+
+        info = {
+            "means": means,
+            "colors": selected_colors,
+            "opacities": selected_opacity,
+            "scales": scales,
+            "quats": rotation,
+            "neural_opacities": neural_opacity,
+            "neural_selection_mask": all_neural_gaussians,
+            "visible_anchor_mask": visible_anchor_mask,
+        }
+        return info

     def rasterize_splats(
         self,
@@ -486,38 +489,35 @@ def rasterize_splats(
         camtoworlds: Tensor,
         Ks: Tensor,
         width: int,
         height: int,
         **kwargs,
-    ) -> Tuple[Tensor, Tensor, Dict, Tensor]:
-
-        # We select only the visible anchors for faster inference
-        visible_anchor_mask = self.get_visible_anchor_mask(camtoworlds=camtoworlds,
-                                                           Ks=Ks,
-                                                           width=width,
-                                                           height=height,
-                                                           packed=self.cfg.packed,
-                                                           rasterize_mode="antialiased" if self.cfg.antialiased else "classic",
-                                                           )
+    ) -> Tuple[Tensor, Tensor, Dict]:

         # Get all the gaussians per voxel spawned from the anchors
-        means, color_mlp, opacities, scales, quats, neural_opacity, neural_selection_mask = self.get_neural_gaussians(camtoworlds[:, :3, 3], visible_anchor_mask=visible_anchor_mask)
+        info = self.get_neural_gaussians(camtoworlds=camtoworlds,
+                                         Ks=Ks,
+                                         width=width,
+                                         height=height,
+                                         packed=self.cfg.packed,
+                                         rasterize_mode="antialiased" if self.cfg.antialiased else "classic",
+                                         )

         image_ids = kwargs.pop("image_ids", None)
         if self.cfg.app_opt:
             colors = self.app_module(
                 features=self.splats["features"],
                 embed_ids=image_ids,
-                dirs=means[None, :, :] - camtoworlds[:, None, :3, 3],
+                dirs=info["means"][None, :, :] - camtoworlds[:, None, :3, 3],
             )
-            colors = colors + color_mlp
+            colors = colors + info["colors"]
             colors = torch.sigmoid(colors)
         else:
-            colors = color_mlp  # [N, K, 3]
+            colors = info["colors"]  # [N, K, 3]

         rasterize_mode = "antialiased" if self.cfg.antialiased else "classic"
-        render_colors, render_alphas, info = rasterization(
-            means=means,
-            quats=quats,
-            scales=scales,
-            opacities=opacities,
+        render_colors, render_alphas, raster_info = rasterization(
+            means=info["means"],
+            quats=info["quats"],
+            scales=info["scales"],
+            opacities=info["opacities"],
             colors=colors,
             viewmats=torch.linalg.inv(camtoworlds),  # [C, 4, 4]
             Ks=Ks,  # [C, 3, 3]
@@ -530,14 +530,10 @@ def rasterize_splats(
             distributed=self.world_size > 1,
             **kwargs,
         )
-        info.update(
-            {
-                "visible_anchor_mask": visible_anchor_mask,
-                "neural_selection_mask": neural_selection_mask,
-                "neural_opacities": neural_opacity,
-            }
+        raster_info.update(
+            info
         )
-        return render_colors, render_alphas, info, scales
+        return render_colors, render_alphas, raster_info

     def train(self):
         cfg = self.cfg
@@ -555,13 +551,16 @@ def train(self):
         schedulers = [
             torch.optim.lr_scheduler.ExponentialLR(
-                self.optimizers["offsets"], gamma=0.01 ** (1.0 / max_steps)
+                self.optimizers["anchors"], gamma=0.001 ** (1.0 / max_steps)
+            ),
+            torch.optim.lr_scheduler.ExponentialLR(
+                self.optimizers["offsets"], gamma=(0.01 * self.scene_scale) ** (1.0 / max_steps)
             ),
             torch.optim.lr_scheduler.ExponentialLR(
-                self.optimizers["opacities_mlp"], gamma=0.002 ** (1.0 / max_steps)
+                self.optimizers["opacities_mlp"], gamma=0.001 ** (1.0 / max_steps)
             ),
             torch.optim.lr_scheduler.ExponentialLR(
-                self.optimizers["colors_mlp"], gamma=0.008 ** (1.0 / max_steps)
+                self.optimizers["colors_mlp"], gamma=0.00625 ** (1.0 / max_steps)
             ),
         ]
         if cfg.pose_opt:
@@ -601,6 +600,11 @@ def train(self):

         # Training loop.
         global_tic = time.time()
         pbar = tqdm.tqdm(range(init_step, max_steps))
+
+        self.cfg.strategy.scale_rot_mlp.train()
+        self.cfg.strategy.opacities_mlp.train()
+        self.cfg.strategy.colors_mlp.train()
+
         for step in pbar:
             if not cfg.disable_viewer:
                 while self.viewer.state.status == "paused":
@@ -634,7 +638,7 @@ def train(self):
                 camtoworlds = self.pose_adjust(camtoworlds, image_ids)

             # forward
-            renders, alphas, info, scales = self.rasterize_splats(
+            renders, alphas, info = self.rasterize_splats(
                 camtoworlds=camtoworlds,
                 Ks=Ks,
                 width=width,
@@ -678,7 +682,7 @@ def train(self):
             )
             loss = l1loss * (1.0 - cfg.ssim_lambda)
             loss += ssimloss * cfg.ssim_lambda
-            loss += scales.prod(dim=1).mean() * cfg.scale_reg
+            loss += info["scales"].prod(dim=1).mean() * cfg.scale_reg
             if cfg.depth_loss:
                 # query depths from depth map
                 points = torch.stack(
@@ -842,7 +846,7 @@ def eval(self, step: int, stage: str = "val"):
             torch.cuda.synchronize()
             tic = time.time()
-            colors, _, _, _ = self.rasterize_splats(
+            colors, _, _ = self.rasterize_splats(
                 camtoworlds=camtoworlds,
                 Ks=Ks,
                 width=width,
@@ -946,7 +950,7 @@ def render_traj(self, step: int):
             camtoworlds = camtoworlds_all[i : i + 1]
             Ks = K[None]

-            renders, _, _, _ = self.rasterize_splats(
+            renders, _, _ = self.rasterize_splats(
                 camtoworlds=camtoworlds,
                 Ks=Ks,
                 width=width,
@@ -1003,7 +1007,7 @@ def _viewer_render_fn(
         c2w = torch.from_numpy(c2w).float().to(self.device)
         K = torch.from_numpy(K).float().to(self.device)

-        render_colors, _, _, _ = self.rasterize_splats(
+        render_colors, _, _ = self.rasterize_splats(
            camtoworlds=c2w[None],
            Ks=K[None],
            width=W,
diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py
index bedcb5a1b..e4b11aa88 100644
--- a/gsplat/strategy/ops.py
+++ b/gsplat/strategy/ops.py
@@ -195,7 +195,14 @@ def remove(
     sel = torch.where(~mask)[0]

     def param_fn(name: str, p: Tensor) -> Tensor:
-        return torch.nn.Parameter(p[sel])
+        if name == "scales":
+            p = p[sel]
+            temp = p[:, 3:]
+            temp[temp > 0.05] = 0.05
+            p[:, 3:] = temp
+            return torch.nn.Parameter(p)
+        else:
+            return torch.nn.Parameter(p[sel])

     def optimizer_fn(key: str, v: Tensor) -> Tensor:
         return v[sel]
@@ -403,8 +410,7 @@ def grow_anchors(
     log_voxel_size = torch.log(torch.tensor(voxel_size, device=device))
     scaling = log_voxel_size.expand(num_new, anchors.size(1) * 2)  # [N_new, 6]

-    rotation = torch.zeros((num_new, 4), device=device)
-    rotation[:, 0] = 1.0  # Identity quaternion
+    rotation = torch.ones((num_new, 4), device=device)

     # Prepare new features
     existing_features = params["features"]  # [N_existing, feat_dim]
@@ -457,3 +463,40 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor:
             zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device)
             state[k] = torch.cat([v, zeros], dim=0)
+
+
+@torch.no_grad()
+def remove_anchors(
+    params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
+    optimizers: Dict[str, torch.optim.Optimizer],
+    n_feat_offsets: int,
+    state: Dict[str, Tensor],
+    mask: Tensor,
+    names: Union[List[str], None] = None,
+):
+    """Inplace remove the Gaussians with the given mask.
+
+    Args:
+        params: A dictionary of parameters.
+        optimizers: A dictionary of optimizers, each corresponding to a parameter.
+        n_feat_offsets: Number of feature offsets.
+        state: A dictionary of extra state tensors.
+        mask: A boolean mask to remove the Gaussians.
+        names: A list of parameter names to update. If None, update all. Default: None.
+ """ + sel = torch.where(~mask)[0] + + def param_fn(name: str, p: Tensor) -> Tensor: + return torch.nn.Parameter(p[sel]) + + def optimizer_fn(key: str, v: Tensor) -> Tensor: + return v[sel] + + # update the parameters and the state in the optimizers + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) + # update the extra running state + for k, v in state.items(): + if isinstance(v, torch.Tensor): + if k in ["anchor_count", "anchor_opacity"]: + state[k] = v[sel] + else: + offset_sel = sel.unsqueeze(dim=1).repeat([1, n_feat_offsets]).view(-1) + state[k] = v[offset_sel] diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index 740ab51b7..086ddfd18 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -1,10 +1,9 @@ from dataclasses import dataclass -from typing import Any, Dict, Tuple, Union - +from typing import Any, Dict, Union import torch from .base import Strategy -from .ops import duplicate, remove, grow_anchors +from .ops import remove_anchors, grow_anchors @dataclass @@ -41,8 +40,6 @@ class ScaffoldStrategy(Strategy): value will be pruned. Default is 0.1. prune_scale2d (float): GSs with 2d scale (normalized by image resolution) above this value will be pruned. Default is 0.15. - refine_scale2d_stop_iter (int): Stop refining GSs based on 2d scale after this - iteration. Default is 0. Set to a positive value to enable this feature. refine_start_iter (int): Start refining GSs after this iteration. Default is 500. refine_stop_iter (int): Stop refining GSs after this iteration. Default is 15_000. refine_every (int): Refine GSs every this steps. Default is 100. @@ -77,8 +74,7 @@ class ScaffoldStrategy(Strategy): grow_scale2d: float = 0.05 prune_scale3d: float = 0.1 prune_scale2d: float = 0.15 - refine_scale2d_stop_iter: int = 0 - refine_start_iter: int = 500 + refine_start_iter: int = 1500 max_voxel_levels: int = 3 voxel_size: float = 0.001 refine_stop_iter: int = 15_000 @@ -96,22 +92,23 @@ class ScaffoldStrategy(Strategy): revised_opacity: bool = False verbose: bool = True + view_distance = 1 colors_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 3 * n_feat_offsets), torch.nn.Sigmoid() ).cuda() opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh() ).cuda() scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( - torch.nn.Linear(feat_dim + 3, feat_dim), + torch.nn.Linear(feat_dim + 3 + 1, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 7 * n_feat_offsets) ).cuda() @@ -126,14 +123,11 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # put them on the correct device. # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. - # - radii: the radii of the GSs (normalized by the image resolution). 
state = {"grad2d": None, "count": None, "anchor_count": None, "anchor_opacity": None, "scene_scale": scene_scale} - if self.refine_scale2d_stop_iter > 0: - state["radii"] = None return state def check_sanity( @@ -166,7 +160,8 @@ def check_sanity( "quats", "opacities_mlp", "colors_mlp", - "scale_rot_mlp"] + "scale_rot_mlp", + ] assert len(expected_params) == len(params), "expected params and actual params don't match" for key in expected_params: @@ -199,28 +194,36 @@ def step_post_backward( if step >= self.refine_stop_iter: return - self._update_state(params, state, info, packed=packed) + if step > 500: + self._update_state(params, state, info, packed=packed) if ( step > self.refine_start_iter and step % self.refine_every == 0 ): # grow GSs - new_anchors = self._anchor_growing(params, optimizers, state, step) + print(f"init anchors: {len(params['anchors'])}") + new_anchors = self._anchor_growing(params, optimizers, state) if self.verbose: print( f"Step {step}: {new_anchors} anchors grown." f"Now having {len(params['anchors'])} anchors." ) - # prune GSs - is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() - is_prune = torch.logical_and(is_prune, state["anchor_count"] > self.refine_every * self.pruning_thresholds) + # prune anchors + low_opacity_mask = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() + anchor_mask = (state["anchor_count"] > self.pruning_thresholds * self.refine_every) # [N, 1] + is_prune = torch.logical_and(low_opacity_mask, anchor_mask) + + indices = anchor_mask.nonzero(as_tuple=False).squeeze() + # Efficiently set the specified indices to zero + state["anchor_count"].index_fill_(0, indices, 0) + state["anchor_opacity"].index_fill_(0, indices, 0) n_prune = is_prune.sum().item() if n_prune > 0: names = ["anchors", "scales", "quats", "features", "offsets"] - remove(params=params, optimizers=optimizers, state=state, mask=is_prune, names=names) + remove_anchors(params=params, optimizers=optimizers, n_feat_offsets=self.n_feat_offsets, state=state, mask=is_prune, names=names) if self.verbose: print( @@ -228,20 +231,6 @@ def step_post_backward( f"Now having {len(params['anchors'])} anchors." ) - # reset running stats - state["grad2d"].zero_() - state["count"].zero_() - state["anchor_count"].zero_() - state["anchor_opacity"].zero_() - device = params["anchors"].device - n_gaussian = params["anchors"].shape[0] - state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) - state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) - state["anchor_count"] = torch.zeros(n_gaussian, device=device) - state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) - - if self.refine_scale2d_stop_iter > 0: - state["radii"].zero_() torch.cuda.empty_cache() def _update_state( @@ -257,8 +246,9 @@ def _update_state( assert key in info, f"{key} is required but missing." 
         # Normalize gradients to [-1, 1] screen space
+        factor = 0.5 * info["n_cameras"]
         scale_factors = torch.tensor(
-            [info["width"] / 2.0 * info["n_cameras"], info["height"] / 2.0 * info["n_cameras"]],
+            [info["width"] * factor, info["height"] * factor],
             device=info["means2d"].device,
         )
@@ -278,44 +268,25 @@ def _update_state(
             state["anchor_count"] = torch.zeros(n_gaussian, device=device)
         if state["anchor_opacity"] is None:
             state["anchor_opacity"] = torch.zeros(n_gaussian, device=device)
-        if self.refine_scale2d_stop_iter > 0 and state["radii"] is None:
-            state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device)
-
-        # Update the running state
-        if packed:
-            # grads is [nnz, 2]
-            gs_ids = info["gaussian_ids"]  # [nnz]
-            radii = info["radii"]  # [nnz]
-        else:
-            # grads is [C, N, 2]
-            sel = info["radii"] > 0.0  # [C, N]
-            gs_ids = sel.nonzero(as_tuple=False)[:, 1]  # [nnz]
-            grads = grads[sel]  # [nnz, 2]
-            radii = info["radii"][sel]  # [nnz]

-        # Compute valid_mask efficiently
-        visible_anchor_mask = info["visible_anchor_mask"]
         neural_selection_mask = info["neural_selection_mask"]

-        # Compute anchor indices
-        anchor_indices = gs_ids // self.n_feat_offsets
-
-        # Determine valid gs_ids based on visibility and selection masks
-        valid_mask = (
-            visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids]
-        )
+        # Update the running state
+        sel = info["radii"] > 0.0  # [C, N]
+        neural_ids = neural_selection_mask.nonzero(as_tuple=False).squeeze(-1)
+        grads = grads[sel].norm(dim=-1)  # [nnz]
+        sel = sel.squeeze(0)

-        # Filter gs_ids and grads based on valid_mask
-        valid_gs_ids = gs_ids[valid_mask]
-        valid_grads_norm = grads[valid_mask].norm(dim=-1)
+        valid_ids = neural_ids[sel]

         # Update state using index_add_
-        state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm)
+        state["grad2d"].index_add_(0, valid_ids, grads)
         state["count"].index_add_(
-            0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)
+            0, valid_ids, torch.ones((valid_ids.shape[0]), dtype=torch.float32, device=device)
         )

         # Update anchor opacity and count
+        visible_anchor_mask = info["visible_anchor_mask"]
         anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1)
         neural_opacities = (
             info["neural_opacities"]
@@ -330,23 +301,13 @@ def _update_state(
             0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32)
         )

-        # Update radii if required
-        if self.refine_scale2d_stop_iter > 0:
-            # Normalize radii to [0, 1] screen space
-            normalized_radii = radii / float(max(info["width"], info["height"]))
-            # Update radii using torch.maximum
-            state["radii"][gs_ids] = torch.maximum(
-                state["radii"][gs_ids], normalized_radii
-            )
-
     @torch.no_grad()
     def _anchor_growing(
         self,
         params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict],
         optimizers: Dict[str, torch.optim.Optimizer],
         state: Dict[str, Any],
-        step: int,
-    ) -> Tuple[int]:
+    ) -> int:
         """
         Implements the Anchor Growing algorithm as described in Algorithm 1
         of the GS-Scaffold appendix:

        Args:
            params (Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict]):
                A dictionary of parameters.
            optimizers (Dict[str, torch.optim.Optimizer]):
                A dictionary of optimizers associated with the model's parameters.
            state (Dict[str, Any]):
                A dictionary containing the current state of the training process.
-           step (int):
-               The current step or iteration of the training process.

        Returns:
-           Tuple[int]:
-               A tuple containing the updated step value after anchor growing.
+           int:
+               Number of grown anchors.
""" count = state["count"] @@ -375,15 +334,7 @@ def _anchor_growing( grads[grads.isnan()] = 0.0 device = grads.device - # is_grad_high = (grads > self.grow_grad2d).squeeze(-1) - # is_small = (n - # torch.exp(params["scales"]).max(dim=-1).values - # <= self.grow_scale3d * state["scene_scale"] - # ) - # is_dupli = is_grad_high & is_small - # n_dupli = is_dupli.sum().item() - - n_init_features = params["features"].shape[0] * self.n_feat_offsets + n_init_features = state["count"].shape[0] n_added_anchors = 0 # Algorithm 1: Anchor Growing @@ -394,39 +345,80 @@ def _anchor_growing( epsilon_g = self.voxel_size new_anchors: torch.tensor - growing_treshold = state["count"] > self.refine_every * self.growing_thresholds + growing_threshold_mask = state["count"] > self.refine_every * self.growing_thresholds # Step 2: Iterate while m <= M while m <= M: - n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features + n_feature_diff = state["count"].shape[0] - n_init_features # Check if anchor candidates have grown - if n_feature_diff == 0 and (m-1) > 0: + if n_feature_diff == 0 and m > 1: break # Step 3: Update threshold and voxel size - tau = tau_g * 2 ** (m - 1) + tau = tau_g * (2 ** (m - 1)) current_voxel_size = (16 // (4 ** (m - 1))) * epsilon_g # Step 4: Mask from grad threshold. Select neural gaussians (Select candidates) gradient_mask = grads >= tau - # - gradient_mask = torch.logical_and(gradient_mask, growing_treshold) + gradient_mask = torch.logical_and(gradient_mask, growing_threshold_mask) + + # Drop-out: Helps prevent too massive anchor growth. + rand_mask = torch.rand_like(gradient_mask.float())>(0.4**m) + rand_mask = rand_mask.cuda() + gradient_mask = torch.logical_and(gradient_mask, rand_mask) gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0) - neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:,:3].unsqueeze(dim=1) + + # Compute neural gaussians + neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"][:,:3]).unsqueeze(dim=1) selected_neural_gaussians = neural_gaussians.view([-1,3])[gradient_mask] # Step 5: Merge same positions selected_grid_coords = torch.round(selected_neural_gaussians / current_voxel_size).int() selected_grid_coords_unique, inv_idx = torch.unique(selected_grid_coords, return_inverse=True, dim=0) - # Step 6: Random elimination - # TODO: implement rand elimination (necessary)? - # Get the grid coordinates of the current anchors grid_anchor_coords = torch.round(params["anchors"] / current_voxel_size).int() - # Step 7: Remove occupied by comparing the unique coordinates to current anchors - remove_occupied_pos_mask = ~( - (selected_grid_coords_unique.unsqueeze(1) == grid_anchor_coords).all(-1).any(-1).view(-1)) + # Step 6: Remove occupied by comparing the unique coordinates to current anchors + def coords_to_indices(coords, N): + """ + Maps quantized multi-dimensional coordinates to unique 1D indices without using torch.matmul. + + Args: + coords (torch.Tensor): Tensor of shape [num_points, D], where D is the dimensionality. + N (int): A large enough number to ensure uniqueness of indices. + + Returns: + torch.Tensor: Tensor of unique indices corresponding to the coordinates. 
+ """ + D = coords.shape[1] + device = coords.device + dtype = coords.dtype # Keep the original data type + + # Compute N_powers as integers + N_powers = N ** torch.arange(D - 1, -1, -1, device=device, dtype=dtype) + + # Perform element-wise multiplication and sum along the last dimension + indices = (coords * N_powers).sum(dim=1) + + return indices + + # Quantize the coordinates + decimal_places = 6 # precision + scale = 10 ** decimal_places + + # Quantize the coordinates by scaling and converting to integers + selected_coords_quant = (selected_grid_coords_unique * scale).long() + anchor_coords_quant = (grid_anchor_coords * scale).long() + + # Compute the maximum coordinate value + max_coord_value = torch.max(torch.cat([selected_coords_quant, anchor_coords_quant])) + 1 + N = max_coord_value.item() + + # Compute unique indices for both coordinate sets + indices_selected = coords_to_indices(selected_coords_quant, N) + indices_anchor = coords_to_indices(anchor_coords_quant, N) + + remove_occupied_pos_mask = ~torch.isin(indices_selected, indices_anchor) # New anchor candidates are those unique coordinates that are not duplicates new_anchors = selected_grid_coords_unique[remove_occupied_pos_mask] @@ -446,34 +438,10 @@ def _anchor_growing( n_added_anchors += new_anchors.shape[0] m += 1 - return n_added_anchors + indices = torch.arange(n_init_features, device=growing_threshold_mask.device)[growing_threshold_mask] - @torch.no_grad() - def _prune_gs( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - optimizers: Dict[str, torch.optim.Optimizer], - state: Dict[str, Any], - step: int, - ) -> int: - is_prune = torch.sigmoid(params["opacities"].flatten()) < self.prune_opa - if step > self.reset_every: - is_too_big = ( - torch.exp(params["scales"]).max(dim=-1).values - > self.prune_scale3d * state["scene_scale"] - ) - # The official code also implements sreen-size pruning but - # it's actually not being used due to a bug: - # https://github.com/graphdeco-inria/gaussian-splatting/issues/123 - # We implement it here for completeness but set `refine_scale2d_stop_iter` - # to 0 by default to disable it. - if step < self.refine_scale2d_stop_iter: - is_too_big |= state["radii"] > self.prune_scale2d - - is_prune = is_prune | is_too_big - - n_prune = is_prune.sum().item() - if n_prune > 0: - remove(params=params, optimizers=optimizers, state=state, mask=is_prune) - - return n_prune + if indices.numel() > 0: + state["count"].index_fill_(0, indices, 0) + state["grad2d"].index_fill_(0, indices, 0) + + return n_added_anchors \ No newline at end of file