diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
index dd2c062e..99c4ddac 100644
--- a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
+++ b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py
@@ -572,6 +572,67 @@ def update_fn(updates, state, params):
   return optax.GradientTransformation(init_fn, update_fn)
 
 
+def scale_by_adaptive_gd_small_r() -> optax.GradientTransformation:
+  """Rescale updates according to adaptive GD (small-r variant).
+
+  Returns:
+    An (init_fn, update_fn) tuple.
+  """
+
+  def init_fn(params):
+    init_params = jax.tree_map(jnp.copy, params)  # x0
+    prev_update = jax.tree_map(
+        jnp.zeros_like, params
+    )  # previous update, with the step size / lr included
+    return ScaleBy_Adaptive_GD_State(
+        r_squared=1e-3 * jnp.ones([], jnp.float64),
+        lambda_prev=jnp.zeros([], jnp.float64),
+        lambda_sum=jnp.zeros([], jnp.float64),
+        init_params=init_params,
+        prev_update=prev_update,
+    )
+
+  def update_fn(updates, state, params):
+    # We can use layer-wise distances later for a layer-wise variant.
+    layer_wise_curr_distance_squared = jax.tree_map(
+        lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
+    )
+    curr_distance_norm_squared = utils.total_tree_sum(
+        layer_wise_curr_distance_squared
+    )
+    # curr_r_squared plays the role of r_t^2 here.
+    curr_r_squared = jnp.maximum(state.r_squared, curr_distance_norm_squared)
+    new_updates = jax.tree_map(
+        lambda g, g_prev: g - state.lambda_prev * g_prev,
+        updates,
+        state.prev_update,
+    )
+    new_update_norm_squared = utils.total_tree_norm_sql2(new_updates)
+    lambda_new = 0.5 * (
+        jnp.sqrt(
+            state.lambda_sum**2
+            + jnp.divide(new_update_norm_squared, curr_r_squared)
+        )
+        - state.lambda_sum
+    )
+    lambda_sum_new = state.lambda_sum + lambda_new
+    new_updates_with_lr = jax.tree_map(
+        lambda u: u / lambda_sum_new, new_updates
+    )
+    negative_new_updates_with_lr = jax.tree_map(
+        lambda u: -u, new_updates_with_lr
+    )
+    return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
+        r_squared=curr_r_squared,
+        lambda_prev=lambda_new,
+        lambda_sum=lambda_sum_new,
+        init_params=state.init_params,
+        prev_update=negative_new_updates_with_lr,
+    )
+
+  return optax.GradientTransformation(init_fn, update_fn)
+
+
 def scale_by_layerwise_adaptive_gd() -> optax.GradientTransformation:
   """Rescale updates according to LAYER-WISE Adaptive GD.
 
   Returns:
     An (init_fn, update_fn) tuple.
   """
 
@@ -638,6 +699,74 @@ def update_fn(updates, state, params):
   return optax.GradientTransformation(init_fn, update_fn)
 
 
+def scale_by_layerwise_adaptive_gd_small_r() -> optax.GradientTransformation:
+  """Rescale updates according to LAYER-WISE Adaptive GD (small-r variant).
+
+  Returns:
+    An (init_fn, update_fn) tuple.
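+
+  A minimal usage sketch (hypothetical driving loop; like any optax
+  GradientTransformation, this is typically composed with further
+  transforms such as a final sign/learning-rate scaling)::
+
+    tx = scale_by_layerwise_adaptive_gd_small_r()
+    opt_state = tx.init(params)
+    updates, opt_state = tx.update(grads, opt_state, params)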
+ """ + + def init_fn(params): + init_params = jax.tree_map(jnp.copy, params) # x0 + prev_update = jax.tree_map( + jnp.zeros_like, params + ) # previous update with step-size/lr included + return ScaleBy_Adaptive_GD_State( + r_squared=jax.tree_map( + lambda x: 1e-12 * jnp.ones([], jnp.float64), params + ), + lambda_prev=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params), + lambda_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params), + init_params=init_params, + prev_update=prev_update, + ) + + def update_fn(updates, state, params): + layer_wise_curr_distance_squared = jax.tree_map( + lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params + ) + curr_distance_norm_squared = layer_wise_curr_distance_squared + # curr_r_squared plays the role of r_t^2 here + curr_r_squared = jax.tree_map( + jnp.maximum, + state.r_squared, + curr_distance_norm_squared, + ) + new_updates = jax.tree_map( + lambda g, g_prev, l_prev: g - l_prev * g_prev, + updates, + state.prev_update, + state.lambda_prev, + ) + new_update_norm_squared = jax.tree_map( + lambda u: jnp.sum(u ** 2), new_updates + ) + lambda_new = jax.tree_map( + lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + jnp.divide(g, r)) - l), + state.lambda_sum, + new_update_norm_squared, + curr_r_squared, + ) + lambda_sum_new = jax.tree_map( + lambda l1, l2: l1 + l2, state.lambda_sum, lambda_new + ) + new_updates_with_lr = jax.tree_map( + lambda u, l: u / l, new_updates, lambda_sum_new + ) + negative_new_updates_with_lr = jax.tree_map( + lambda u: -u, new_updates_with_lr + ) + return new_updates_with_lr, ScaleBy_Adaptive_GD_State( + r_squared=curr_r_squared, + lambda_prev=lambda_new, + lambda_sum=lambda_sum_new, + init_params=state.init_params, + prev_update=negative_new_updates_with_lr, + ) + + return optax.GradientTransformation(init_fn, update_fn) + + def scale_by_coordinate_wise_adaptive_gd() -> optax.GradientTransformation: """Rescale updates according to COORDINATE-WISE Adaptive GD. @@ -705,6 +834,155 @@ def update_fn(updates, state, params): return optax.GradientTransformation(init_fn, update_fn) +class ScaleBy_Adaptive_GD_Simple_State(NamedTuple): + """State for the simpler adaptive GD algorithm.""" + + r_squared: Any + mu_sum: Any + init_params: optax.Updates + + +def scale_by_adaptive_gd_simple() -> optax.GradientTransformation: + """Rescale updates according to simpler adaptive GD. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + + def init_fn(params): + init_params = jax.tree_map(jnp.copy, params) # x0 + return ScaleBy_Adaptive_GD_Simple_State( + r_squared=jnp.ones([], jnp.float64), + mu_sum=jnp.zeros([], jnp.float64), + init_params=init_params, + ) + + def update_fn(updates, state, params): + # we can use layer-wise distances later for a layer-wise variant + layer_wise_curr_distance_squared = jax.tree_map( + lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params + ) + curr_distance_norm_squared = utils.total_tree_sum( + layer_wise_curr_distance_squared + ) + curr_r_squared = jnp.maximum(state.r_squared, curr_distance_norm_squared) + update_norm_squared = utils.total_tree_norm_sql2(updates) + mu_sum_new = 0.5 * ( + jnp.sqrt( + state.mu_sum**2 + + jnp.divide((4*update_norm_squared), curr_r_squared) + ) + + state.mu_sum + ) + new_updates_with_lr = jax.tree_map( + lambda u: u / mu_sum_new, updates + ) + return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( + r_squared=curr_r_squared, + mu_sum=mu_sum_new, + init_params=state.init_params, + ) + + return optax.GradientTransformation(init_fn, update_fn) + + +def scale_by_layerwise_adaptive_gd_simple() -> optax.GradientTransformation: + """Rescale updates according to simpler LAYER-WISE Adaptive GD. + + Returns: + An (init_fn, update_fn) tuple. + """ + + def init_fn(params): + init_params = jax.tree_map(jnp.copy, params) # x0 + return ScaleBy_Adaptive_GD_Simple_State( + r_squared=jax.tree_map(lambda x: jnp.ones([], jnp.float64), params), + mu_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params), + init_params=init_params, + ) + + def update_fn(updates, state, params): + curr_distance_norm_squared = jax.tree_map( + lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params + ) + curr_r_squared = jax.tree_map( + jnp.maximum, + state.r_squared, + curr_distance_norm_squared, + ) + update_norm_squared = jax.tree_map( + lambda u: jnp.sum(u ** 2), updates + ) + mu_sum_new = jax.tree_map( + lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + 4 * jnp.divide(g, r)) + l), + state.mu_sum, + update_norm_squared, + curr_r_squared, + ) + new_updates_with_lr = jax.tree_map( + lambda u, l: u / l, updates, mu_sum_new + ) + return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( + r_squared=curr_r_squared, + mu_sum=mu_sum_new, + init_params=state.init_params, + ) + + return optax.GradientTransformation(init_fn, update_fn) + + +def scale_by_coordinate_wise_adaptive_gd_simple() -> ( + optax.GradientTransformation +): + """Rescale updates according to simpler COORDINATE-WISE Adaptive GD. + + Returns: + An (init_fn, update_fn) tuple. 
+ """ + + def init_fn(params): + init_params = jax.tree_map(jnp.copy, params) # x0 + return ScaleBy_Adaptive_GD_Simple_State( + # r_squared=jax.tree_map( + # lambda x: jnp.ones_like(x) / jnp.size(x), + # params, + # ), + # trying with r^2=1 for all params for now + r_squared=jax.tree_map(jnp.ones_like, params), + mu_sum=jax.tree_map(jnp.zeros_like, params), + init_params=init_params, + ) + + def update_fn(updates, state, params): + curr_distance_norm_squared = jax.tree_map( + lambda x_t, x_0: jnp.square(x_t - x_0), params, state.init_params + ) + curr_r_squared = jax.tree_map( + jnp.maximum, + state.r_squared, + curr_distance_norm_squared, + ) + update_norm_squared = jax.tree_map( + jnp.square, updates + ) + mu_sum_new = jax.tree_map( + lambda l, g, r: 0.5*(jnp.sqrt(jnp.square(l) + 4*jnp.divide(g, r)) + l), + state.mu_sum, + update_norm_squared, + curr_r_squared, + ) + new_updates_with_lr = jax.tree_map( + jnp.divide, updates, mu_sum_new + ) + return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( + r_squared=curr_r_squared, + mu_sum=mu_sum_new, + init_params=state.init_params, + ) + + return optax.GradientTransformation(init_fn, update_fn) + + # TODO(namanagarwal): Add a test for Nadam class ScaleByAdamState(NamedTuple): """State for the NAdam algorithm.""" @@ -1489,10 +1767,21 @@ def update_fn(updates, state, params): # scale_by_rms exists only for backward compatability _composites = { 'scale_by_adaptive_gd': scale_by_adaptive_gd, + 'scale_by_adaptive_gd_small_r': scale_by_adaptive_gd_small_r, + 'scale_by_adaptive_gd_simple': scale_by_adaptive_gd_simple, 'scale_by_layerwise_adaptive_gd': scale_by_layerwise_adaptive_gd, + 'scale_by_layerwise_adaptive_gd_small_r': ( + scale_by_layerwise_adaptive_gd_small_r + ), + 'scale_by_layerwise_adaptive_gd_simple': ( + scale_by_layerwise_adaptive_gd_simple + ), 'scale_by_coordinate_wise_adaptive_gd': ( scale_by_coordinate_wise_adaptive_gd ), + 'scale_by_coordinate_wise_adaptive_gd_simple': ( + scale_by_coordinate_wise_adaptive_gd_simple + ), 'scale_by_adam': scale_by_adam, 'scale_by_adam_plus': scale_by_adam_plus, 'scale_by_yogi': optax.scale_by_yogi,