Commit

PiperOrigin-RevId: 558441600
init2winit Team authored and copybara-github committed Aug 19, 2023
1 parent bd18fa9 commit de9c1ac
Showing 1 changed file with 289 additions and 0 deletions.

init2winit/optimizer_lib/kitchen_sink/_src/transform.py
@@ -572,6 +572,67 @@ def update_fn(updates, state, params):
return optax.GradientTransformation(init_fn, update_fn)


def scale_by_adaptive_gd_small_r() -> optax.GradientTransformation:
"""Rescale updates according to adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
prev_update = jax.tree_map(
jnp.zeros_like, params
) # previous update with step-size/lr included
return ScaleBy_Adaptive_GD_State(
r_squared=1e-3*jnp.ones([], jnp.float64),
lambda_prev=jnp.zeros([], jnp.float64),
lambda_sum=jnp.zeros([], jnp.float64),
init_params=init_params,
prev_update=prev_update,
)

def update_fn(updates, state, params):
# we can use layer-wise distances later for a layer-wise variant
layer_wise_curr_distance_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_distance_norm_squared = utils.total_tree_sum(
layer_wise_curr_distance_squared
)
# curr_r_squared plays the role of r_t^2 here
curr_r_squared = jnp.maximum(state.r_squared, curr_distance_norm_squared)
new_updates = jax.tree_map(
lambda g, g_prev: g - state.lambda_prev * g_prev,
updates,
state.prev_update,
)
new_update_norm_squared = utils.total_tree_norm_sql2(new_updates)
lambda_new = 0.5 * (
jnp.sqrt(
state.lambda_sum**2
+ jnp.divide(new_update_norm_squared, curr_r_squared)
)
- state.lambda_sum
)
lambda_sum_new = state.lambda_sum + lambda_new
new_updates_with_lr = jax.tree_map(
lambda u: u / lambda_sum_new, new_updates
)
negative_new_updates_with_lr = jax.tree_map(
lambda u: -u, new_updates_with_lr
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
r_squared=curr_r_squared,
lambda_prev=lambda_new,
lambda_sum=lambda_sum_new,
init_params=state.init_params,
prev_update=negative_new_updates_with_lr,
)

return optax.GradientTransformation(init_fn, update_fn)
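# --- Usage sketch (editorial illustration, not part of this diff) -----------
# Assumptions: a toy parameter pytree, stand-in gradients, and a trailing
# optax.scale(-1.0) so that optax.apply_updates takes a descent step; the
# actual sign/learning-rate handling depends on how the surrounding
# kitchen_sink chain is configured.
import jax.numpy as jnp
import optax

tx = optax.chain(scale_by_adaptive_gd_small_r(), optax.scale(-1.0))
toy_params = {'w': jnp.ones((3,))}       # hypothetical parameters
opt_state = tx.init(toy_params)
toy_grads = {'w': 0.1 * jnp.ones((3,))}  # hypothetical gradients
updates, opt_state = tx.update(toy_grads, opt_state, toy_params)
toy_params = optax.apply_updates(toy_params, updates)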


def scale_by_layerwise_adaptive_gd() -> optax.GradientTransformation:
"""Rescale updates according to LAYER-WISE Adaptive GD.
@@ -638,6 +699,74 @@ def update_fn(updates, state, params):
return optax.GradientTransformation(init_fn, update_fn)


def scale_by_layerwise_adaptive_gd_small_r() -> optax.GradientTransformation:
"""Rescale updates according to LAYER-WISE Adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
prev_update = jax.tree_map(
jnp.zeros_like, params
) # previous update with step-size/lr included
return ScaleBy_Adaptive_GD_State(
r_squared=jax.tree_map(
lambda x: 1e-12 * jnp.ones([], jnp.float64), params
),
lambda_prev=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
lambda_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
init_params=init_params,
prev_update=prev_update,
)

def update_fn(updates, state, params):
layer_wise_curr_distance_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_distance_norm_squared = layer_wise_curr_distance_squared
# curr_r_squared plays the role of r_t^2 here
curr_r_squared = jax.tree_map(
jnp.maximum,
state.r_squared,
curr_distance_norm_squared,
)
new_updates = jax.tree_map(
lambda g, g_prev, l_prev: g - l_prev * g_prev,
updates,
state.prev_update,
state.lambda_prev,
)
new_update_norm_squared = jax.tree_map(
lambda u: jnp.sum(u ** 2), new_updates
)
lambda_new = jax.tree_map(
lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + jnp.divide(g, r)) - l),
state.lambda_sum,
new_update_norm_squared,
curr_r_squared,
)
lambda_sum_new = jax.tree_map(
lambda l1, l2: l1 + l2, state.lambda_sum, lambda_new
)
new_updates_with_lr = jax.tree_map(
lambda u, l: u / l, new_updates, lambda_sum_new
)
negative_new_updates_with_lr = jax.tree_map(
lambda u: -u, new_updates_with_lr
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
r_squared=curr_r_squared,
lambda_prev=lambda_new,
lambda_sum=lambda_sum_new,
init_params=state.init_params,
prev_update=negative_new_updates_with_lr,
)

return optax.GradientTransformation(init_fn, update_fn)
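# For reference (editorial note; notation introduced here, not taken from the
# source): both "small_r" variants implement the recursion below. The global
# version above treats all parameters as one block with r_0^2 = 1e-3, while
# this layer-wise version runs the same recursion independently for every leaf
# of the parameter tree with r_0^2 = 1e-12 per leaf. Writing x_0 = init_params,
# g_t = incoming update, lambda/Lambda = lambda_prev/lambda_sum, and
# u_t = emitted update:
#
#   r_t^2    = max(r_{t-1}^2, ||x_t - x_0||^2)
#   g~_t     = g_t - lambda_{t-1} * p_{t-1},  where p_{t-1} = -u_{t-1}, p_0 = 0
#   lambda_t = 0.5 * (sqrt(Lambda_{t-1}^2 + ||g~_t||^2 / r_t^2) - Lambda_{t-1})
#   Lambda_t = Lambda_{t-1} + lambda_t
#   u_t      = g~_t / Lambda_t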


def scale_by_coordinate_wise_adaptive_gd() -> optax.GradientTransformation:
"""Rescale updates according to COORDINATE-WISE Adaptive GD.
@@ -705,6 +834,155 @@ def update_fn(updates, state, params):
return optax.GradientTransformation(init_fn, update_fn)


class ScaleBy_Adaptive_GD_Simple_State(NamedTuple):
"""State for the simpler adaptive GD algorithm."""

r_squared: Any
mu_sum: Any
init_params: optax.Updates


def scale_by_adaptive_gd_simple() -> optax.GradientTransformation:
"""Rescale updates according to simpler adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
return ScaleBy_Adaptive_GD_Simple_State(
r_squared=jnp.ones([], jnp.float64),
mu_sum=jnp.zeros([], jnp.float64),
init_params=init_params,
)

def update_fn(updates, state, params):
# we can use layer-wise distances later for a layer-wise variant
layer_wise_curr_distance_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_distance_norm_squared = utils.total_tree_sum(
layer_wise_curr_distance_squared
)
curr_r_squared = jnp.maximum(state.r_squared, curr_distance_norm_squared)
update_norm_squared = utils.total_tree_norm_sql2(updates)
mu_sum_new = 0.5 * (
jnp.sqrt(
state.mu_sum**2
+ jnp.divide((4*update_norm_squared), curr_r_squared)
)
+ state.mu_sum
)
new_updates_with_lr = jax.tree_map(
lambda u: u / mu_sum_new, updates
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State(
r_squared=curr_r_squared,
mu_sum=mu_sum_new,
init_params=state.init_params,
)

return optax.GradientTransformation(init_fn, update_fn)
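# Written out (editorial note; same notation as above, with mu_sum playing the
# role of Lambda): the "simple" variant drops the correction by the previous
# update and accumulates the step-size denominator in closed form, starting
# from r_0^2 = 1 and Lambda_0 = 0:
#
#   r_t^2    = max(r_{t-1}^2, ||x_t - x_0||^2)
#   Lambda_t = 0.5 * (sqrt(Lambda_{t-1}^2 + 4 * ||g_t||^2 / r_t^2) + Lambda_{t-1})
#   u_t      = g_t / Lambda_t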


def scale_by_layerwise_adaptive_gd_simple() -> optax.GradientTransformation:
"""Rescale updates according to simpler LAYER-WISE Adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
return ScaleBy_Adaptive_GD_Simple_State(
r_squared=jax.tree_map(lambda x: jnp.ones([], jnp.float64), params),
mu_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
init_params=init_params,
)

def update_fn(updates, state, params):
curr_distance_norm_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_r_squared = jax.tree_map(
jnp.maximum,
state.r_squared,
curr_distance_norm_squared,
)
update_norm_squared = jax.tree_map(
lambda u: jnp.sum(u ** 2), updates
)
mu_sum_new = jax.tree_map(
lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + 4 * jnp.divide(g, r)) + l),
state.mu_sum,
update_norm_squared,
curr_r_squared,
)
new_updates_with_lr = jax.tree_map(
lambda u, l: u / l, updates, mu_sum_new
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State(
r_squared=curr_r_squared,
mu_sum=mu_sum_new,
init_params=state.init_params,
)

return optax.GradientTransformation(init_fn, update_fn)
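# Quick first-step sanity check of the per-layer recursion (editorial note;
# numbers chosen arbitrarily): assuming the first update is computed at the
# initial parameters (so the distance term is 0) and a layer with
# ||g_1||^2 = 4,
#
#   r_1^2    = max(1, 0) = 1
#   Lambda_1 = 0.5 * (sqrt(0 + 4 * 4 / 1) + 0) = 2
#   u_1      = g_1 / 2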


def scale_by_coordinate_wise_adaptive_gd_simple() -> (
optax.GradientTransformation
):
"""Rescale updates according to simpler COORDINATE-WISE Adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
return ScaleBy_Adaptive_GD_Simple_State(
# r_squared=jax.tree_map(
# lambda x: jnp.ones_like(x) / jnp.size(x),
# params,
# ),
# trying with r^2=1 for all params for now
r_squared=jax.tree_map(jnp.ones_like, params),
mu_sum=jax.tree_map(jnp.zeros_like, params),
init_params=init_params,
)

def update_fn(updates, state, params):
curr_distance_norm_squared = jax.tree_map(
lambda x_t, x_0: jnp.square(x_t - x_0), params, state.init_params
)
curr_r_squared = jax.tree_map(
jnp.maximum,
state.r_squared,
curr_distance_norm_squared,
)
update_norm_squared = jax.tree_map(
jnp.square, updates
)
mu_sum_new = jax.tree_map(
lambda l, g, r: 0.5*(jnp.sqrt(jnp.square(l) + 4*jnp.divide(g, r)) + l),
state.mu_sum,
update_norm_squared,
curr_r_squared,
)
new_updates_with_lr = jax.tree_map(
jnp.divide, updates, mu_sum_new
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State(
r_squared=curr_r_squared,
mu_sum=mu_sum_new,
init_params=state.init_params,
)

return optax.GradientTransformation(init_fn, update_fn)
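# --- Granularity comparison (editorial illustration, not part of this diff) --
# The three "simple" variants differ only in how finely r_squared and mu_sum
# are tracked; the two-leaf parameter tree below is a made-up example.
import jax
import jax.numpy as jnp

toy_tree = {'b': jnp.zeros((3,)), 'w': jnp.zeros((2, 3))}
global_state = scale_by_adaptive_gd_simple().init(toy_tree)
layer_state = scale_by_layerwise_adaptive_gd_simple().init(toy_tree)
coord_state = scale_by_coordinate_wise_adaptive_gd_simple().init(toy_tree)

print(global_state.mu_sum.shape)                     # () -- one scalar overall
print(jax.tree_map(jnp.shape, layer_state.mu_sum))   # {'b': (), 'w': ()} -- one scalar per layer
print(jax.tree_map(jnp.shape, coord_state.mu_sum))   # {'b': (3,), 'w': (2, 3)} -- one value per coordinate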


# TODO(namanagarwal): Add a test for Nadam
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
@@ -1489,10 +1767,21 @@ def update_fn(updates, state, params):
# scale_by_rms exists only for backward compatibility
_composites = {
'scale_by_adaptive_gd': scale_by_adaptive_gd,
'scale_by_adaptive_gd_small_r': scale_by_adaptive_gd_small_r,
'scale_by_adaptive_gd_simple': scale_by_adaptive_gd_simple,
'scale_by_layerwise_adaptive_gd': scale_by_layerwise_adaptive_gd,
'scale_by_layerwise_adaptive_gd_small_r': (
scale_by_layerwise_adaptive_gd_small_r
),
'scale_by_layerwise_adaptive_gd_simple': (
scale_by_layerwise_adaptive_gd_simple
),
'scale_by_coordinate_wise_adaptive_gd': (
scale_by_coordinate_wise_adaptive_gd
),
'scale_by_coordinate_wise_adaptive_gd_simple': (
scale_by_coordinate_wise_adaptive_gd_simple
),
'scale_by_adam': scale_by_adam,
'scale_by_adam_plus': scale_by_adam_plus,
'scale_by_yogi': optax.scale_by_yogi,
