diff --git a/src/levanter/optim/config.py b/src/levanter/optim/config.py index d1f4841fc..d814a6b64 100644 --- a/src/levanter/optim/config.py +++ b/src/levanter/optim/config.py @@ -180,7 +180,7 @@ def lr_scheduler(self, num_train_steps): if stable_steps != 0: stable = optax.constant_schedule(self.learning_rate) schedules.append(stable) - boundaries.append(start + stable_steps) + boundaries.append(start + warmup_steps + stable_steps) match self.lr_schedule: case "constant":