Skip to content

Commit

Permalink
add rademacher random option to sophiah
Browse files Browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed Apr 24, 2024
1 parent ad99c98 commit 0dc998a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/levanter/optim/sophia.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import levanter.tracker
from levanter.optim.config import HessianOptConfig, OptimizerConfig
from levanter.optim.util import hvp, tree_gaussian_like
from levanter.optim.util import hvp, tree_gaussian_like, tree_rademacher_like
from levanter.utils.jax_utils import parameter_count, tree_filter_like


Expand Down Expand Up @@ -199,9 +199,10 @@ def _optimizer(learning_rate, gamma) -> optax.GradientTransformation:
@dataclass
class SophiaHConfig(BaseSophiaConfig):
gamma: float = GAMMA_SOPHIA_H
rand: str = "gaussian"

def compute_hessian(self, fn, model, *batch, hess_key: PRNGKey, **batch_kwargs):
return stochastic_hessian_diagonal(fn, model, *batch, **batch_kwargs, hess_key=hess_key)
return stochastic_hessian_diagonal(fn, model, *batch, **batch_kwargs, hess_key=hess_key, rand=self.rand)


def sophia_h(
Expand Down Expand Up @@ -423,7 +424,7 @@ def stochastic_diag_gauss_newton(fn, model, *args, hess_key: PRNGKey, **kwargs):


# Use this for Sophia-H
def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs):
def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, rand: str, **kwargs):
"""Compute the diagonal of the Hessian of a function using a normal distribution.
https://arxiv.org/pdf/2305.14342.pdf Algorithm 1
Expand All @@ -436,7 +437,10 @@ def stochastic_hessian_diagonal(fn, model, *args, hess_key: PRNGKey, **kwargs):
# cf https://arxiv.org/pdf/2006.00719.pdf eqn 9
# https://www-users.cse.umn.edu/~saad/PDF/umsi-2005-082.pdf
# https://arxiv.org/pdf/2208.03268.pdf
g = tree_gaussian_like(hess_key, model)
if rand == "rademacher":
g = tree_rademacher_like(hess_key, model)
else:
g = tree_gaussian_like(hess_key, model)
# TODO: consider allowing for n > 1 gaussians?
product = hvp(lambda m: fn(m, *args, **kwargs), model, g)
hessian = jax.tree_util.tree_map(lambda grad, gaussian: grad * gaussian, product, g)
Expand Down
13 changes: 13 additions & 0 deletions src/levanter/optim/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,16 @@ def tree_gaussian_like(key, tree):
g = jax.tree_util.tree_unflatten(structure, g)

return g

def tree_rademacher_like(key, tree):
leaves, structure = jax.tree_util.tree_flatten(tree)
keys = jax.random.split(key, len(leaves))
# paper uses normal but we use rademacher
# see https://www.ethanepperly.com/index.php/2024/01/28/dont-use-gaussians-in-stochastic-trace-estimation/
g = jax.tree_util.tree_map(
lambda key, x: jax.random.rademacher(key, x.shape, dtype=jnp.float32),
list(keys),
leaves,
)
g = jax.tree_util.tree_unflatten(structure, g)
return g

0 comments on commit 0dc998a

Please sign in to comment.