diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 47cc7bf4..3d448c69 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -135,6 +135,7 @@ def get_optimizer(hps, model=None, batch_axis_name=None): nesterov=(hps.optimizer == 'nesterov')) elif hps.optimizer == 'tearfree': sketch_size = hps.opt_hparams.get('sketchy_rank') + dynamic_size = hps.opt_hparams.get('dynamic_sketchy_rank') if sketch_size is not None and sketch_size > 0: opts = tearfree_sketchy.Options( update_freq=hps.opt_hparams['update_preconditioners_freq'], @@ -153,6 +154,25 @@ def get_optimizer(hps, model=None, batch_axis_name=None): 'shampoo_options': None, 'second_order_type': tearfree_second_order.SecondOrderType.SKETCHY, } + elif dynamic_size is not None and dynamic_size > 0: + opts = tearfree_dynamic_sketchy.Options( + update_freq=hps.opt_hparams['update_preconditioners_freq'], + second_moment_decay=hps.opt_hparams['beta2'], + rank=dynamic_size, + epsilon=hps.opt_hparams.get('matrix_epsilon', 1e-16), + relative_epsilon=hps.opt_hparams.get( + 'matrix_relative_epsilon', False + ), + delta=hps.opt_hparams.get('delta', 1e-5), + eta=hps.opt_hparams.get('ekfac_svd', 1e-5), + err_tol=hps.opt_hparams.get('err_tol', 1e-5), + seed=hps.opt_hparams.get('seed', 0), + ) + opts = { + 'dynamic_sketchy_options': opts, + 'shampoo_options': None, + 'second_order_type': tearfree_second_order.SecondOrderType.DYNAMIC, + } else: opts = tearfree_shampoo.Options( second_moment_decay=hps.opt_hparams['beta2'],