better params
MrNeRF committed Sep 20, 2024
1 parent 1c101f7 commit 09ebdfa
Showing 1 changed file with 16 additions and 20 deletions.
examples/simple_trainer_scaffold.py (16 additions & 20 deletions)
@@ -205,12 +205,12 @@ def create_splats_with_optimizers(
 
     params = [
         # name, value, lr
-        ("anchors", torch.nn.Parameter(points), 1.6e-4 * scene_scale),
-        ("scales", torch.nn.Parameter(scales), 5e-3),
-        ("quats", torch.nn.Parameter(quats), 1e-3),
-        ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.02),
-        ("features", torch.nn.Parameter(features), 1.6e-4 * scene_scale),
-        ("offsets", torch.nn.Parameter(offsets), 0.004),
+        ("anchors", torch.nn.Parameter(points), 0),
+        ("scales", torch.nn.Parameter(scales), 0.007),
+        ("quats", torch.nn.Parameter(quats), 0.002),
+        ("opacities_mlp", strategy.opacities_mlp.parameters(), 0.002),
+        ("features", torch.nn.Parameter(features), 0.0075 * scene_scale),
+        ("offsets", torch.nn.Parameter(offsets), 0.01),
         ("colors_mlp", strategy.colors_mlp.parameters(), 0.008),
         ("scale_rot_mlp", strategy.scale_rot_mlp.parameters(), 0.004),
     ]
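
For context, create_splats_with_optimizers turns each (name, value, lr) triple above into its own optimizer, which is what lets the schedulers further down target "offsets", "opacities_mlp", and "colors_mlp" by name. A minimal sketch of that one-optimizer-per-group pattern (the choice of Adam and the eps value are assumptions for illustration, not read from this diff):

import torch

def build_optimizers(params):
    # One optimizer per named group, so every group keeps an independent
    # learning rate that a scheduler can later decay on its own.
    optimizers = {}
    for name, value, lr in params:
        if isinstance(value, torch.nn.Parameter):
            value = [value]  # Adam expects an iterable of parameters
        optimizers[name] = torch.optim.Adam(
            [{"params": value, "lr": lr, "name": name}],
            eps=1e-15,  # assumption: a small eps, common in splatting trainers
        )
    return optimizers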
@@ -554,9 +554,14 @@ def train(self):
         init_step = 0
 
         schedulers = [
-            # anchors has a learning rate schedule, that end at 0.01 of the initial value
             torch.optim.lr_scheduler.ExponentialLR(
-                self.optimizers["anchors"], gamma=0.01 ** (1.0 / max_steps)
+                self.optimizers["offsets"], gamma=0.01 ** (1.0 / max_steps)
             ),
+            torch.optim.lr_scheduler.ExponentialLR(
+                self.optimizers["opacities_mlp"], gamma=0.002 ** (1.0 / max_steps)
+            ),
+            torch.optim.lr_scheduler.ExponentialLR(
+                self.optimizers["colors_mlp"], gamma=0.008 ** (1.0 / max_steps)
+            ),
         ]
         if cfg.pose_opt:
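
A note on the gamma values: ExponentialLR multiplies the learning rate by gamma on every step, so gamma = target ** (1.0 / max_steps) lands the rate at target times its initial value after max_steps steps. With these settings, "offsets" decays to 1% of its initial lr, "opacities_mlp" to 0.2%, and "colors_mlp" to 0.8%. A small self-contained check of that identity (max_steps and base_lr are illustrative values, not taken from the commit):

import torch

max_steps = 30_000  # illustrative; the trainer reads its real value from cfg
base_lr = 0.01      # matches the new "offsets" lr above

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=base_lr)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.01 ** (1.0 / max_steps))

for _ in range(max_steps):
    opt.step()    # optimizer step first, then scheduler, per PyTorch convention
    sched.step()

print(opt.param_groups[0]["lr"])  # ~1e-4, i.e. 0.01 * base_lr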
@@ -697,15 +702,6 @@ def train(self):
                 tvloss = 10 * total_variation_loss(self.bil_grids.grids)
                 loss += tvloss
 
-            # regularizations
-            # not gonna work. Check this
-            # if cfg.opacity_reg > 0.0:
-            #     loss = (
-            #         loss
-            #         + cfg.opacity_reg
-            #         * torch.abs(torch.sigmoid(self.splats["opacities"])).mean()
-            #     )
-
             loss.backward()
 
             desc = f"loss={loss.item():.3f}| "
@@ -805,9 +801,9 @@ def train(self):
                 scheduler.step()
 
             # eval the full set
-            if step in [i - 1 for i in cfg.eval_steps]:
-                self.eval(step)
-                self.render_traj(step)
+            # if step in [i - 1 for i in cfg.eval_steps]:
+            #     self.eval(step)
+            #     self.render_traj(step)
 
             # run compression
             if cfg.compression is not None and step in [i - 1 for i in cfg.eval_steps]:
