diff --git a/examples/simple_trainer.py b/examples/simple_trainer.py
index 3e544201..096682c3 100644
--- a/examples/simple_trainer.py
+++ b/examples/simple_trainer.py
@@ -724,27 +724,6 @@ def train(self):
                     data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
                 )
 
-            if isinstance(self.cfg.strategy, DefaultStrategy):
-                self.cfg.strategy.step_post_backward(
-                    params=self.splats,
-                    optimizers=self.optimizers,
-                    state=self.strategy_state,
-                    step=step,
-                    info=info,
-                    packed=cfg.packed,
-                )
-            elif isinstance(self.cfg.strategy, MCMCStrategy):
-                self.cfg.strategy.step_post_backward(
-                    params=self.splats,
-                    optimizers=self.optimizers,
-                    state=self.strategy_state,
-                    step=step,
-                    info=info,
-                    lr=schedulers[0].get_last_lr()[0],
-                )
-            else:
-                assert_never(self.cfg.strategy)
-
             # Turn Gradients into Sparse Tensor before running optimizer
             if cfg.sparse_grad:
                 assert cfg.packed, "Sparse gradients only work with packed mode."
@@ -776,6 +755,28 @@ def train(self):
             for scheduler in schedulers:
                 scheduler.step()
 
+            # Run post-backward steps after backward and optimizer
+            if isinstance(self.cfg.strategy, DefaultStrategy):
+                self.cfg.strategy.step_post_backward(
+                    params=self.splats,
+                    optimizers=self.optimizers,
+                    state=self.strategy_state,
+                    step=step,
+                    info=info,
+                    packed=cfg.packed,
+                )
+            elif isinstance(self.cfg.strategy, MCMCStrategy):
+                self.cfg.strategy.step_post_backward(
+                    params=self.splats,
+                    optimizers=self.optimizers,
+                    state=self.strategy_state,
+                    step=step,
+                    info=info,
+                    lr=schedulers[0].get_last_lr()[0],
+                )
+            else:
+                assert_never(self.cfg.strategy)
+
             # eval the full set
             if step in [i - 1 for i in cfg.eval_steps]:
                 self.eval(step)
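
The patch moves the strategy's `step_post_backward` call (densification for `DefaultStrategy`, relocation for `MCMCStrategy`) from before the optimizer step to after it, while `info` from this step's forward pass is still in scope. Below is a toy sketch of why this ordering is presumably the safe one, written in plain PyTorch rather than gsplat's actual strategy code: an optimizer keeps per-parameter state tensors shaped like the parameters, so `optimizer.step()` should consume this step's gradients before anything reallocates those parameters. The duplication-based "densify" and the manual Adam-state rebuild here are illustrative assumptions, not the library's implementation.

```python
import torch

# Consume this step's gradients while the optimizer state (exp_avg,
# exp_avg_sq) still matches the parameter's current shape.
param = torch.nn.Parameter(torch.randn(4, 3))
opt = torch.optim.Adam([param], lr=1e-2)

loss = (param**2).sum()
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)

# Only now "densify": duplicate the rows in place of real splitting
# (hypothetical stand-in), then rebuild the optimizer's per-parameter
# state to the new shape, roughly what a post-backward step must do.
with torch.no_grad():
    new_param = torch.nn.Parameter(torch.cat([param, param.clone()], dim=0))
state = opt.state.pop(param)
for key in ("exp_avg", "exp_avg_sq"):
    state[key] = torch.cat([state[key], torch.zeros_like(state[key])], dim=0)
opt.param_groups[0]["params"] = [new_param]
opt.state[new_param] = state
param = new_param
print(param.shape)  # torch.Size([8, 3])
```

Had the grow/prune run before `opt.step()`, the step would have applied gradients and moment buffers whose shapes no longer match the reallocated parameters; deferring it until after the optimizer and scheduler updates avoids that, and `info` still carries the gradient statistics gathered during this step's backward pass.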