Skip to content

Commit

Permalink
Experiments
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560730284
  • Loading branch information
init2winit Team authored and copybara-github committed Aug 28, 2023
1 parent 9ba7f3f commit c71f90f
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions init2winit/optimizer_lib/kitchen_sink/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,11 +781,13 @@ def update_fn(updates, state, params):

def scale_by_layerwise_adaptive_gd_simple(
init_r_squared: float = 1.0,
eps: float = 1e-8,
) -> optax.GradientTransformation:
"""Rescale updates according to simpler LAYER-WISE Adaptive GD.
Args:
init_r_squared: Initial guess for r^2.
eps: Initial value of mu_sum.
Returns:
An (init_fn, update_fn) tuple.
Expand All @@ -797,7 +799,7 @@ def init_fn(params):
r_squared=jax.tree_map(
lambda x: init_r_squared * jnp.ones([], jnp.float64), params
),
mu_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
mu_sum=jax.tree_map(lambda x: eps * jnp.ones([], jnp.float64), params),
init_params=init_params,
)

Expand Down Expand Up @@ -833,11 +835,13 @@ def update_fn(updates, state, params):

def scale_by_coordinate_wise_adaptive_gd_simple(
init_r_squared: float = 1.0,
eps: float = 1e-8,
) -> optax.GradientTransformation:
"""Rescale updates according to simpler COORDINATE-WISE Adaptive GD.
Args:
init_r_squared: Initial guess for r^2.
eps: Initial value for mu_sum.
Returns:
An (init_fn, update_fn) tuple.
Expand All @@ -846,8 +850,12 @@ def scale_by_coordinate_wise_adaptive_gd_simple(
def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
return ScaleBy_Adaptive_GD_Simple_State(
r_squared=jax.tree_map(init_r_squared * jnp.ones_like, params),
mu_sum=jax.tree_map(jnp.zeros_like, params),
r_squared=jax.tree_map(
lambda x: init_r_squared*jnp.ones_like(x), params
),
mu_sum=jax.tree_map(
lambda x: eps*jnp.ones_like(x), params
),
init_params=init_params,
)

Expand Down

0 comments on commit c71f90f

Please sign in to comment.