diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 55d45fa2..c82c2fc1 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -157,6 +157,10 @@ def get_optimizer(hps, model=None, batch_axis_name=None): .opt_hparams['start_preconditioning_step'], preconditioning_compute_steps=hps .opt_hparams['preconditioning_compute_steps'], + decay_preconditioning_compute_steps=hps + .opt_hparams.get('decay_preconditioning_compute_steps', False), + end_preconditioning_compute_steps=hps + .opt_hparams.get('end_preconditioning_compute_steps', None), statistics_compute_steps=statistics_compute_steps, best_effort_shape_interpretation=hps .opt_hparams['best_effort_shape_interpretation'],