From 3423efb13d9896bedf91cd7c6f2309e3af957614 Mon Sep 17 00:00:00 2001 From: Janusch Patas Date: Fri, 20 Sep 2024 10:36:09 +0200 Subject: [PATCH] add observation threshold --- gsplat/strategy/scaffold.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/gsplat/strategy/scaffold.py b/gsplat/strategy/scaffold.py index f14655fa..740ab51b 100644 --- a/gsplat/strategy/scaffold.py +++ b/gsplat/strategy/scaffold.py @@ -85,22 +85,31 @@ class ScaffoldStrategy(Strategy): feat_dim: int = 32 n_feat_offsets: int = 10 refine_every: int = 100 + + # 3.3 Observation Thresholds (compare paper) + # "To enhance the robustness of the Growing and Pruning operations for long image sequences, ..." + pruning_thresholds: float = 0.8 + growing_thresholds: float = 0.4 + pause_refine_after_reset: int = 0 absgrad: bool = False revised_opacity: bool = False verbose: bool = True + colors_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, 3 * n_feat_offsets), torch.nn.Sigmoid() ).cuda() + opacities_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), torch.nn.Linear(feat_dim, n_feat_offsets), torch.nn.Tanh() ).cuda() + scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential( torch.nn.Linear(feat_dim + 3, feat_dim), torch.nn.ReLU(True), @@ -201,13 +210,12 @@ def step_post_backward( if self.verbose: print( f"Step {step}: {new_anchors} anchors grown." - f"Now having {len(params['anchors'])} GSs." + f"Now having {len(params['anchors'])} anchors." ) # prune GSs is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze() - mask = state["anchor_opacity"] > self.prune_opa - print(mask.sum().item()) + is_prune = torch.logical_and(is_prune, state["anchor_count"] > self.refine_every * self.pruning_thresholds) n_prune = is_prune.sum().item() if n_prune > 0: @@ -216,8 +224,8 @@ def step_post_backward( if self.verbose: print( - f"Step {step}: {n_prune} GSs pruned. " - f"Now having {len(params['anchors'])} GSs." + f"Step {step}: {n_prune} anchors pruned. " + f"Now having {len(params['anchors'])} anchors." ) # reset running stats @@ -386,6 +394,7 @@ def _anchor_growing( epsilon_g = self.voxel_size new_anchors: torch.tensor + growing_treshold = state["count"] > self.refine_every * self.growing_thresholds # Step 2: Iterate while m <= M while m <= M: n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features @@ -399,6 +408,8 @@ def _anchor_growing( # Step 4: Mask from grad threshold. Select neural gaussians (Select candidates) gradient_mask = grads >= tau + # + gradient_mask = torch.logical_and(gradient_mask, growing_treshold) gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0) neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:,:3].unsqueeze(dim=1) selected_neural_gaussians = neural_gaussians.view([-1,3])[gradient_mask]