diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 5cad993f..3da5e9c1 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -369,8 +369,8 @@ def get_optimizer(hps, model=None, batch_axis_name=None): optax.contrib.momo_adam )( learning_rate=0.0, - b1=hps.opt_hparams['b1'], - b2=hps.opt_hparams['b2'], + b1=hps.opt_hparams['beta1'], + b2=hps.opt_hparams['beta2'], eps=hps.opt_hparams['eps'], lower_bound=hps.opt_hparams['lower_bound'], weight_decay=weight_decay,