Commit

Expts
PiperOrigin-RevId: 556949929
init2winit Team authored and copybara-github committed Aug 15, 2023
1 parent f14cc91 commit b7d0dcd
Showing 2 changed files with 231 additions and 0 deletions.
212 changes: 212 additions & 0 deletions init2winit/optimizer_lib/kitchen_sink/_src/transform.py
@@ -14,14 +14,17 @@
# limitations under the License.

"""Transforms."""

import functools
from typing import Any, List, NamedTuple, Optional

import chex
from init2winit.optimizer_lib.kitchen_sink._src import utils
import jax
import jax.numpy as jnp
import optax


# pylint:disable=invalid-name
# pylint:disable=no-value-for-parameter

@@ -477,6 +480,210 @@ def update_fn(updates, state, params=None):
return optax.GradientTransformation(init_fn, update_fn)


class ScaleBy_Adaptive_GD_State(NamedTuple):
"""State for the adaptive GD algorithm."""

r_squared: Any
lambda_prev: Any
lambda_sum: Any
init_params: optax.Updates
prev_update: optax.Updates
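
For orientation, here is a sketch of the recurrence the scalar variant below appears to implement, assuming its output is applied with a negative unit scale so that the stored `prev_update` equals the actual parameter step $x_t - x_{t-1}$ (with $r_0^2 = 1$, $\Lambda_0 = 0$, and $x_0$ the initial parameters):

$$
\begin{aligned}
r_t^2 &= \max\bigl(r_{t-1}^2,\ \lVert x_t - x_0 \rVert^2\bigr),\\
z_t &= g_t - \lambda_{t-1}\,(x_t - x_{t-1}),\\
\lambda_t &= \tfrac{1}{2}\Bigl(\sqrt{\Lambda_{t-1}^2 + \lVert z_t \rVert^2 / r_t^2} - \Lambda_{t-1}\Bigr),
\qquad \Lambda_t = \Lambda_{t-1} + \lambda_t,\\
x_{t+1} &= x_t - z_t / \Lambda_t.
\end{aligned}
$$

Equivalently, $\lambda_t$ is the nonnegative root of $\lambda(\lambda + \Lambda_{t-1}) = \lVert z_t \rVert^2 / (4 r_t^2)$. The layer-wise and coordinate-wise variants further below apply the same recurrence per parameter leaf and per coordinate, respectively.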


def scale_by_adaptive_gd() -> optax.GradientTransformation:
"""Rescale updates according to adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
prev_update = jax.tree_map(
jnp.zeros_like, params
) # previous update with step-size/lr included
return ScaleBy_Adaptive_GD_State(
r_squared=jnp.ones([], jnp.float64),
lambda_prev=jnp.zeros([], jnp.float64),
lambda_sum=jnp.zeros([], jnp.float64),
init_params=init_params,
prev_update=prev_update,
)

def update_fn(updates, state, params):
# we can use layer-wise distances later for a layer-wise variant
layer_wise_curr_distance_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_distance_norm_squared = utils.total_tree_sum(
layer_wise_curr_distance_squared
)
# curr_r_squared plays the role of r_t^2 here
curr_r_squared = jnp.maximum(state.r_squared, curr_distance_norm_squared)
new_updates = jax.tree_map(
lambda g, g_prev: g - state.lambda_prev * g_prev,
updates,
state.prev_update,
)
new_update_norm_squared = utils.total_tree_norm_sql2(new_updates)
lambda_new = 0.5 * (
jnp.sqrt(
state.lambda_sum**2
+ jnp.divide(new_update_norm_squared, curr_r_squared)
)
- state.lambda_sum
)
lambda_sum_new = state.lambda_sum + lambda_new
new_updates_with_lr = jax.tree_map(
lambda u: u / lambda_sum_new, new_updates
)
negative_new_updates_with_lr = jax.tree_map(
lambda u: -u, new_updates_with_lr
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
r_squared=curr_r_squared,
lambda_prev=lambda_new,
lambda_sum=lambda_sum_new,
init_params=state.init_params,
prev_update=negative_new_updates_with_lr,
)

return optax.GradientTransformation(init_fn, update_fn)
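
A minimal usage sketch. Assumptions not in the commit: the illustrative quadratic loss and parameter shapes, the chaining with `optax.scale(-1.0)` so that `optax.apply_updates` performs descent (and so the stored `prev_update` matches the applied step), and enabling x64 so the float64 state fields are honored (without it JAX silently falls back to float32).

```python
import jax
import jax.numpy as jnp
import optax

from init2winit.optimizer_lib.kitchen_sink._src import transform

jax.config.update('jax_enable_x64', True)  # state scalars are declared float64


def loss(params):
  # Illustrative quadratic objective; any differentiable loss works.
  return jnp.sum((params['w'] - 3.0) ** 2)


params = {'w': jnp.zeros(4)}
# Chain with a negative unit scale; the transform already folds the adaptive
# step size (1 / lambda_sum) into its output.
tx = optax.chain(transform.scale_by_adaptive_gd(), optax.scale(-1.0))
state = tx.init(params)

for _ in range(100):
  grads = jax.grad(loss)(params)
  updates, state = tx.update(grads, state, params)
  params = optax.apply_updates(params, updates)

print(params['w'])  # should approach the minimizer at 3.0
```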


def scale_by_layerwise_adaptive_gd() -> optax.GradientTransformation:
"""Rescale updates according to LAYER-WISE Adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
prev_update = jax.tree_map(
jnp.zeros_like, params
) # previous update with step-size/lr included
return ScaleBy_Adaptive_GD_State(
r_squared=jax.tree_map(lambda x: jnp.ones([], jnp.float64), params),
lambda_prev=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
lambda_sum=jax.tree_map(lambda x: jnp.zeros([], jnp.float64), params),
init_params=init_params,
prev_update=prev_update,
)

def update_fn(updates, state, params):
layer_wise_curr_distance_squared = jax.tree_map(
lambda x_t, x_0: jnp.sum((x_t - x_0) ** 2), params, state.init_params
)
curr_distance_norm_squared = layer_wise_curr_distance_squared
# curr_r_squared plays the role of r_t^2 here
curr_r_squared = jax.tree_map(
jnp.maximum,
state.r_squared,
curr_distance_norm_squared,
)
new_updates = jax.tree_map(
lambda g, g_prev, l_prev: g - l_prev * g_prev,
updates,
state.prev_update,
state.lambda_prev,
)
new_update_norm_squared = jax.tree_map(
lambda u: jnp.sum(u ** 2), new_updates
)
lambda_new = jax.tree_map(
lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + jnp.divide(g, r)) - l),
state.lambda_sum,
new_update_norm_squared,
curr_r_squared,
)
lambda_sum_new = jax.tree_map(
lambda l1, l2: l1 + l2, state.lambda_sum, lambda_new
)
new_updates_with_lr = jax.tree_map(
lambda u, l: u / l, new_updates, lambda_sum_new
)
negative_new_updates_with_lr = jax.tree_map(
lambda u: -u, new_updates_with_lr
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
r_squared=curr_r_squared,
lambda_prev=lambda_new,
lambda_sum=lambda_sum_new,
init_params=state.init_params,
prev_update=negative_new_updates_with_lr,
)

return optax.GradientTransformation(init_fn, update_fn)
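
The layer-wise variant keeps one (r², λ, Λ) triple per leaf of the parameter pytree, so different layers can settle on different effective step sizes. A short sketch under the same illustrative assumptions as above:

```python
import jax
import jax.numpy as jnp
import optax

from init2winit.optimizer_lib.kitchen_sink._src import transform

params = {'dense': jnp.zeros((3, 2)), 'bias': jnp.zeros(2)}
tx = optax.chain(transform.scale_by_layerwise_adaptive_gd(), optax.scale(-1.0))
state = tx.init(params)

grads = jax.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)

# lambda_sum is a pytree with one scalar per leaf rather than a single global
# scalar, mirroring the structure of `params`.
print(state[0].lambda_sum)
```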


def scale_by_coordinate_wise_adaptive_gd() -> optax.GradientTransformation:
"""Rescale updates according to COORDINATE-WISE Adaptive GD.
Returns:
An (init_fn, update_fn) tuple.
"""

def init_fn(params):
init_params = jax.tree_map(jnp.copy, params) # x0
prev_update = jax.tree_map(
jnp.zeros_like, params
) # previous update with step-size/lr included
return ScaleBy_Adaptive_GD_State(
r_squared=jax.tree_map(
lambda x: jnp.ones_like(x) / jnp.size(x),
params,
),
lambda_prev=jax.tree_map(jnp.zeros_like, params),
lambda_sum=jax.tree_map(jnp.zeros_like, params),
init_params=init_params,
prev_update=prev_update,
)

def update_fn(updates, state, params):
curr_distance_norm_squared = jax.tree_map(
lambda x_t, x_0: jnp.square(x_t - x_0), params, state.init_params
)
curr_r_squared = jax.tree_map(
jnp.maximum,
state.r_squared,
curr_distance_norm_squared,
)
new_updates = jax.tree_map(
lambda g, g_prev, l_prev: g - jnp.multiply(l_prev, g_prev),
updates,
state.prev_update,
state.lambda_prev,
)
new_update_norm_squared = jax.tree_map(
jnp.square, new_updates
)
lambda_new = jax.tree_map(
lambda l, g, r: 0.5 * (jnp.sqrt(jnp.square(l) + jnp.divide(g, r)) - l),
state.lambda_sum,
new_update_norm_squared,
curr_r_squared,
)
lambda_sum_new = jax.tree_map(
lambda l1, l2: l1 + l2, state.lambda_sum, lambda_new
)
new_updates_with_lr = jax.tree_map(
jnp.divide, new_updates, lambda_sum_new
)
negative_new_updates_with_lr = jax.tree_map(
lambda u: -u, new_updates_with_lr
)
return new_updates_with_lr, ScaleBy_Adaptive_GD_State(
r_squared=curr_r_squared,
lambda_prev=lambda_new,
lambda_sum=lambda_sum_new,
init_params=state.init_params,
prev_update=negative_new_updates_with_lr,
)

return optax.GradientTransformation(init_fn, update_fn)
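
The coordinate-wise variant tracks the same quantities entry by entry; note that `r_squared` starts at `1 / size(leaf)` per coordinate, so its initial per-leaf total matches the scalar variant's r² = 1. A quick shape check, again purely illustrative:

```python
import jax.numpy as jnp

from init2winit.optimizer_lib.kitchen_sink._src import transform

params = {'w': jnp.zeros(4)}
tx = transform.scale_by_coordinate_wise_adaptive_gd()
state = tx.init(params)

# Every state field mirrors the parameter shapes: one value per coordinate.
print(state.r_squared['w'])         # [0.25 0.25 0.25 0.25]
print(state.lambda_sum['w'].shape)  # (4,)
```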


# TODO(namanagarwal): Add a test for Nadam
class ScaleByAdamState(NamedTuple):
"""State for the NAdam algorithm."""
@@ -1209,6 +1416,11 @@ def update_fn(updates, state, params):

# scale_by_rms exists only for backward compatibility
_composites = {
'scale_by_adaptive_gd': scale_by_adaptive_gd,
'scale_by_layerwise_adaptive_gd': scale_by_layerwise_adaptive_gd,
'scale_by_coordinate_wise_adaptive_gd': (
scale_by_coordinate_wise_adaptive_gd
),
'scale_by_adam': scale_by_adam,
'scale_by_adam_plus': scale_by_adam_plus,
'scale_by_yogi': optax.scale_by_yogi,
19 changes: 19 additions & 0 deletions init2winit/optimizer_lib/kitchen_sink/_src/utils.py
@@ -16,13 +16,32 @@
"""Optimizer utilities."""

import copy
import operator

from absl import logging
import flax
import jax
import jax.numpy as jnp
import optax


def total_tree_sum(pytree):
"""Compute the overall sum of a pytree."""
sums = jax.tree_map(jnp.sum, pytree)
return jax.tree_util.tree_reduce(operator.add, sums, 0)


def tree_norm_sql2(pytree):
"""Compute the param-wise squared L2 norm of a pytree."""
return jax.tree_map(lambda x: jnp.linalg.norm(x.reshape(-1)) ** 2, pytree)


def total_tree_norm_sql2(pytree):
"""Compute the overall squared L2 norm of a pytree."""
sql2_norms = tree_norm_sql2(pytree)
return jax.tree_util.tree_reduce(operator.add, sql2_norms, 0)
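
A quick sketch of the three helpers on a toy pytree, with values chosen so the results are easy to check by hand (the tree itself is illustrative):

```python
import jax.numpy as jnp

from init2winit.optimizer_lib.kitchen_sink._src import utils

tree = {'a': jnp.array([3.0, 4.0]), 'b': jnp.array([[1.0, 2.0]])}

print(utils.total_tree_sum(tree))        # 3 + 4 + 1 + 2 = 10.0
print(utils.tree_norm_sql2(tree))        # {'a': 25.0, 'b': 5.0}
print(utils.total_tree_norm_sql2(tree))  # 25 + 5 = 30.0
```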


def is_leaf(x):
return isinstance(x, dict) and 'element' in x

