diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py
index e24f1833b..196cfb527 100644
--- a/src/levanter/optim/kron.py
+++ b/src/levanter/optim/kron.py
@@ -133,10 +133,9 @@ def _optimizer(learning_rate) -> optax.GradientTransformation:
             components.append(optax.scale_by_learning_rate(learning_rate))
             return optax.chain(*components)
 
-        # return optax.inject_hyperparams(_optimizer)(
-        #     learning_rate=self.lr_scheduler(num_train_steps)
-        # )
-        return _optimizer(self.lr_scheduler(num_train_steps))
+        return optax.inject_hyperparams(_optimizer)(
+            learning_rate=self.lr_scheduler(num_train_steps)
+        )
 
 
 from typing import Any, List, Optional, Union, Callable
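
For context (not part of the diff): `optax.inject_hyperparams` wraps an optimizer factory so that the keyword arguments it receives, here `learning_rate`, are recorded in the optimizer state under `hyperparams`, where they can be inspected or overridden at each step instead of being baked into the transformation. The sketch below is a minimal illustration of that behaviour using `optax.sgd` and a made-up linear schedule as stand-ins, not the Kron optimizer from this patch.

```python
# Minimal sketch of optax.inject_hyperparams; the schedule and toy params are
# hypothetical stand-ins, not values used by Levanter's Kron config.
import jax.numpy as jnp
import optax

# A schedule passed as a hyperparameter is re-evaluated every update step.
schedule = optax.linear_schedule(init_value=1e-3, end_value=0.0, transition_steps=100)

# Wrapping the factory exposes learning_rate in the optimizer state.
opt = optax.inject_hyperparams(optax.sgd)(learning_rate=schedule)

params = {"w": jnp.ones((3,))}
opt_state = opt.init(params)

grads = {"w": jnp.ones((3,))}
updates, opt_state = opt.update(grads, opt_state, params)

# The scheduled learning rate is tracked in the state rather than hidden
# inside the chained transformations.
print(opt_state.hyperparams["learning_rate"])
```

Re-enabling the wrapper here presumably serves the same purpose: the scheduled learning rate for the Kron optimizer becomes visible in the optimizer state, e.g. for logging, in the same way as for the other optax-based optimizers.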