From 1c101f781566d3f266602ed0424e9229ee6c0b39 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Thu, 19 Sep 2024 11:40:14 +0200 Subject: [PATCH] implement the anchor pruning mechanism --- gsplat/strategy/ops.py | 11 +++- gsplat/strategy/scaffold.py | 128 +++++++++++++++++++++++++----------- 2 files changed, 96 insertions(+), 43 deletions(-) diff --git a/gsplat/strategy/ops.py b/gsplat/strategy/ops.py index 9b0fb81d..bedcb5a1 100644 --- a/gsplat/strategy/ops.py +++ b/gsplat/strategy/ops.py @@ -181,13 +181,16 @@ def remove( optimizers: Dict[str, torch.optim.Optimizer], state: Dict[str, Tensor], mask: Tensor, + names: Union[List[str], None] = None, ): """Inplace remove the Gaussian with the given mask. Args: params: A dictionary of parameters. optimizers: A dictionary of optimizers, each corresponding to a parameter. + state: A dictionary of extra state tensors. mask: A boolean mask to remove the Gaussians. + names: A list of key names to update. If None, update all. Default: None. """ sel = torch.where(~mask)[0] @@ -198,7 +201,7 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: return v[sel] # update the parameters and the state in the optimizers - _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers) + _update_param_with_optimizer(param_fn, optimizer_fn, params, optimizers, names) # update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): @@ -360,7 +363,6 @@ def op_sigmoid(x, k=100, x0=0.995): * (op_sigmoid(1 - opacities)).unsqueeze(-1) * scaler ) - noise = torch.einsum("bij,bj->bi", covars, noise) params["means"].add_(noise) @@ -449,6 +451,9 @@ def optimizer_fn(key: str, v: Tensor) -> Tensor: # Update the extra running state for k, v in state.items(): if isinstance(v, torch.Tensor): - zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) + if k == "anchor_count" or k == "anchor_opacity": + zeros = torch.zeros((num_new, *v.shape[1:]), device=device) + else: + zeros = torch.zeros((num_new * n_feat_offsets, *v.shape[1:]), device=device) state[k] = torch.cat([v, zeros], dim=0) diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index bd72d600..f14655fa 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -118,9 +118,10 @@ def initialize_state(self, scene_scale: float = 1.0) -> Dict[str, Any]: # - grad2d: running accum of the norm of the image plane gradients for each GS. # - count: running accum of how many time each GS is visible. # - radii: the radii of the GSs (normalized by the image resolution). - # - radii: the radii of the GSs (normalized by the image resolution). state = {"grad2d": None, "count": None, + "anchor_count": None, + "anchor_opacity": None, "scene_scale": scene_scale} if self.refine_scale2d_stop_iter > 0: state["radii"] = None @@ -204,51 +205,75 @@ def step_post_backward( ) # prune GSs - # n_prune = self._prune_gs(params, optimizers, state, step) - # if self.verbose: - # print( - # f"Step {step}: {n_prune} GSs pruned. " - # f"Now having {len(params['anchors'])} GSs." - # ) + is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() + mask = state["anchor_opacity"] > self.prune_opa + print(mask.sum().item()) + + n_prune = is_prune.sum().item() + if n_prune > 0: + names = ["anchors", "scales", "quats", "features", "offsets"] + remove(params=params, optimizers=optimizers, state=state, mask=is_prune, names=names) + + if self.verbose: + print( + f"Step {step}: {n_prune} GSs pruned. " + f"Now having {len(params['anchors'])} GSs." + ) # reset running stats state["grad2d"].zero_() state["count"].zero_() + state["anchor_count"].zero_() + state["anchor_opacity"].zero_() + device = params["anchors"].device + n_gaussian = params["anchors"].shape[0] + state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + state["anchor_count"] = torch.zeros(n_gaussian, device=device) + state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) if self.refine_scale2d_stop_iter > 0: state["radii"].zero_() torch.cuda.empty_cache() def _update_state( - self, - params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], - state: Dict[str, Any], - info: Dict[str, Any], - packed: bool = False, + self, + params: Union[Dict[str, torch.nn.Parameter], torch.nn.ParameterDict], + state: Dict[str, Any], + info: Dict[str, Any], + packed: bool = False, ): - for key in ["width", "height", "n_cameras", "radii", "gaussian_ids"]: + # Ensure required keys are present + required_keys = ["width", "height", "n_cameras", "radii", "gaussian_ids"] + for key in required_keys: assert key in info, f"{key} is required but missing." - # normalize grads to [-1, 1] screen space + # Normalize gradients to [-1, 1] screen space + scale_factors = torch.tensor( + [info["width"] / 2.0 * info["n_cameras"], info["height"] / 2.0 * info["n_cameras"]], + device=info["means2d"].device, + ) + if self.absgrad: - grads = info["means2d"].absgrad.clone() + grads = info["means2d"].absgrad.detach() * scale_factors else: - grads = info["means2d"].grad.clone() - grads[..., 0] *= info["width"] / 2.0 * info["n_cameras"] - grads[..., 1] *= info["height"] / 2.0 * info["n_cameras"] + grads = info["means2d"].grad.detach() * scale_factors - # initialize state on the first run + # Initialize state on the first run n_gaussian = params["anchors"].shape[0] + device = grads.device if state["grad2d"] is None: - state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) + state["grad2d"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) if state["count"] is None: - state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) + state["count"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) + if state["anchor_count"] is None: + state["anchor_count"] = torch.zeros(n_gaussian, device=device) + if state["anchor_opacity"] is None: + state["anchor_opacity"] = torch.zeros(n_gaussian, device=device) if self.refine_scale2d_stop_iter > 0 and state["radii"] is None: - assert "radii" in info, "radii is required but missing." - state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=grads.device) - + state["radii"] = torch.zeros(n_gaussian * self.n_feat_offsets, device=device) - # update the running state + # Update the running state if packed: # grads is [nnz, 2] gs_ids = info["gaussian_ids"] # [nnz] @@ -256,31 +281,54 @@ def _update_state( else: # grads is [C, N, 2] sel = info["radii"] > 0.0 # [C, N] - gs_ids = torch.where(sel)[1] # [nnz] + gs_ids = sel.nonzero(as_tuple=False)[:, 1] # [nnz] grads = grads[sel] # [nnz, 2] radii = info["radii"][sel] # [nnz] - # update neural gaussian statis + + # Compute valid_mask efficiently visible_anchor_mask = info["visible_anchor_mask"] neural_selection_mask = info["neural_selection_mask"] - # Extend to - anchor_visible_mask = visible_anchor_mask.unsqueeze(dim=1).repeat([1, self.n_feat_offsets]).view(-1) - neural_gaussian_mask = torch.zeros_like(state["grad2d"], dtype=torch.bool) - neural_gaussian_mask[anchor_visible_mask] = neural_selection_mask - valid_mask = neural_gaussian_mask[gs_ids] - # Filter gs_ids and grads based on the valid_mask + # Compute anchor indices + anchor_indices = gs_ids // self.n_feat_offsets + + # Determine valid gs_ids based on visibility and selection masks + valid_mask = ( + visible_anchor_mask[anchor_indices] & neural_selection_mask[gs_ids] + ) + + # Filter gs_ids and grads based on valid_mask valid_gs_ids = gs_ids[valid_mask] - valid_grads_norm = grads.norm(dim=-1)[valid_mask] + valid_grads_norm = grads[valid_mask].norm(dim=-1) + # Update state using index_add_ state["grad2d"].index_add_(0, valid_gs_ids, valid_grads_norm) - state["count"].index_add_(0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32)) - + state["count"].index_add_( + 0, valid_gs_ids, torch.ones_like(valid_gs_ids, dtype=torch.float32) + ) + + # Update anchor opacity and count + anchor_ids = visible_anchor_mask.nonzero(as_tuple=False).squeeze(-1) + neural_opacities = ( + info["neural_opacities"] + .detach() + .view(-1, self.n_feat_offsets) + .clamp_min_(0) + .sum(dim=1) + ) + + state["anchor_opacity"].index_add_(0, anchor_ids, neural_opacities) + state["anchor_count"].index_add_( + 0, anchor_ids, torch.ones_like(anchor_ids, dtype=torch.float32) + ) + + # Update radii if required if self.refine_scale2d_stop_iter > 0: - # Should be ideally using scatter max + # Normalize radii to [0, 1] screen space + normalized_radii = radii / float(max(info["width"], info["height"])) + # Update radii using torch.maximum state["radii"][gs_ids] = torch.maximum( - state["radii"][gs_ids], - # normalize radii to [0, 1] screen space - radii / float(max(info["width"], info["height"])), + state["radii"][gs_ids], normalized_radii ) @torch.no_grad()