add observation threshold
MrNeRF committed Sep 20, 2024
1 parent 09ebdfa commit 3423efb
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions gsplat/strategy/scaffold.py
@@ -85,22 +85,31 @@ class ScaffoldStrategy(Strategy):
     feat_dim: int = 32
     n_feat_offsets: int = 10
     refine_every: int = 100
+
+    # 3.3 Observation Thresholds (cf. paper)
+    # "To enhance the robustness of the Growing and Pruning operations for long image sequences, ..."
+    pruning_thresholds: float = 0.8
+    growing_thresholds: float = 0.4
+
     pause_refine_after_reset: int = 0
     absgrad: bool = False
     revised_opacity: bool = False
     verbose: bool = True
+
     colors_mlp: torch.nn.Sequential = torch.nn.Sequential(
         torch.nn.Linear(feat_dim + 3, feat_dim),
         torch.nn.ReLU(True),
         torch.nn.Linear(feat_dim, 3 * n_feat_offsets),
         torch.nn.Sigmoid()
     ).cuda()
+
     opacities_mlp: torch.nn.Sequential = torch.nn.Sequential(
         torch.nn.Linear(feat_dim + 3, feat_dim),
         torch.nn.ReLU(True),
         torch.nn.Linear(feat_dim, n_feat_offsets),
         torch.nn.Tanh()
     ).cuda()
+
     scale_rot_mlp: torch.nn.Sequential = torch.nn.Sequential(
         torch.nn.Linear(feat_dim + 3, feat_dim),
         torch.nn.ReLU(True),
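
For orientation, these heads follow the Scaffold-GS design: each anchor's feat_dim feature is concatenated with a 3-D viewing direction and decoded into one attribute per offset. Below is a minimal, self-contained sketch of that shape flow for the two fully shown heads (CPU, toy sizes; the tensor names and the normalized camera-to-anchor direction are assumptions, not taken from this file):

    import torch

    feat_dim, n_feat_offsets, N = 32, 10, 4      # N anchors (toy value)

    colors_mlp = torch.nn.Sequential(
        torch.nn.Linear(feat_dim + 3, feat_dim),
        torch.nn.ReLU(True),
        torch.nn.Linear(feat_dim, 3 * n_feat_offsets),
        torch.nn.Sigmoid(),
    )
    opacities_mlp = torch.nn.Sequential(
        torch.nn.Linear(feat_dim + 3, feat_dim),
        torch.nn.ReLU(True),
        torch.nn.Linear(feat_dim, n_feat_offsets),
        torch.nn.Tanh(),
    )

    features = torch.randn(N, feat_dim)                                  # per-anchor feature
    view_dir = torch.nn.functional.normalize(torch.randn(N, 3), dim=1)   # assumed camera-to-anchor direction
    x = torch.cat([features, view_dir], dim=1)                           # (N, feat_dim + 3)

    colors = colors_mlp(x).view(N, n_feat_offsets, 3)      # one sigmoid RGB per offset
    opacities = opacities_mlp(x).view(N, n_feat_offsets)   # one tanh opacity per offset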
@@ -201,13 +210,12 @@ def step_post_backward(
             if self.verbose:
                 print(
                     f"Step {step}: {new_anchors} anchors grown."
-                    f"Now having {len(params['anchors'])} GSs."
+                    f"Now having {len(params['anchors'])} anchors."
                 )

             # prune GSs
             is_prune = (state["anchor_opacity"] < self.prune_opa * state["anchor_count"]).squeeze()
-            mask = state["anchor_opacity"] > self.prune_opa
-            print(mask.sum().item())
+            is_prune = torch.logical_and(is_prune, state["anchor_count"] > self.refine_every * self.pruning_thresholds)

             n_prune = is_prune.sum().item()
             if n_prune > 0:
@@ -216,8 +224,8 @@ def step_post_backward(

             if self.verbose:
                 print(
-                    f"Step {step}: {n_prune} GSs pruned. "
-                    f"Now having {len(params['anchors'])} GSs."
+                    f"Step {step}: {n_prune} anchors pruned. "
+                    f"Now having {len(params['anchors'])} anchors."
                 )

             # reset running stats
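
In words: an anchor is now pruned only if its accumulated opacity is low relative to its observation count and it was actually observed in at least pruning_thresholds (80%) of the last refine_every steps. A toy illustration of the pruning mask with made-up numbers (prune_opa and the running stats are placeholders; names follow the diff):

    import torch

    refine_every, pruning_thresholds, prune_opa = 100, 0.8, 0.005

    anchor_count   = torch.tensor([100.0, 30.0, 95.0])   # observations since the last reset
    anchor_opacity = torch.tensor([0.2, 0.1, 0.3])       # opacity accumulated over those observations

    is_prune = (anchor_opacity < prune_opa * anchor_count).squeeze()
    # New in this commit: rarely observed anchors are never pruned.
    is_prune = torch.logical_and(is_prune, anchor_count > refine_every * pruning_thresholds)
    # -> tensor([True, False, True]); anchor 1 is spared despite its low opacity.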
@@ -386,6 +394,7 @@ def _anchor_growing(
         epsilon_g = self.voxel_size
         new_anchors: torch.tensor

+        growing_threshold = state["count"] > self.refine_every * self.growing_thresholds
         # Step 2: Iterate while m <= M
         while m <= M:
             n_feature_diff = params["features"].shape[0] * self.n_feat_offsets - n_init_features
@@ -399,6 +408,8 @@

             # Step 4: Mask from grad threshold. Select neural gaussians (Select candidates)
             gradient_mask = grads >= tau
+            # New: only keep candidates whose anchors were observed often enough.
+            gradient_mask = torch.logical_and(gradient_mask, growing_threshold)
             gradient_mask = torch.cat([gradient_mask, torch.zeros(n_feature_diff, dtype=torch.bool, device=device)], dim=0)
             neural_gaussians = params["anchors"].unsqueeze(dim=1) + params["offsets"] * torch.exp(params["scales"])[:, :3].unsqueeze(dim=1)
             selected_neural_gaussians = neural_gaussians.view([-1, 3])[gradient_mask]
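
Growing is gated the same way, but with the looser growing_thresholds (40%), applied to the gradient mask before candidate neural Gaussians are selected. A toy sketch with made-up numbers (tau is a placeholder; for brevity the vectors are per anchor, whereas the code keeps these stats per neural-Gaussian offset):

    import torch

    refine_every, growing_thresholds, tau = 100, 0.4, 2e-4

    count = torch.tensor([90.0, 25.0, 60.0])   # observations since the last reset
    grads = torch.tensor([3e-4, 5e-4, 1e-5])   # accumulated gradient magnitude

    growing_threshold = count > refine_every * growing_thresholds   # tensor([True, False, True])
    gradient_mask = grads >= tau                                    # tensor([True, True, False])
    # New in this commit: candidates must also clear the observation threshold.
    gradient_mask = torch.logical_and(gradient_mask, growing_threshold)
    # -> tensor([True, False, False]); a large gradient alone no longer triggers growing.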