Skip to content

Commit

Permalink
correct t
Browse files: browse the repository at this point in the history
  • Loading branch information
blahBlahhhJ committed Apr 29, 2024
1 parent 173454d commit ff73972
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/levanter/optim/schedulefree_sophia.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs):
if mu_dtype is not None:
z = jax.tree_util.tree_util.tree_map(lambda t: t.astype(mu_dtype), z)

+ # update t
+ t += 1
# update y
new_y = jax.tree_util.tree_map(
lambda y, z, u: (1 - 1 / t) * y + (1 / t) * z + learning_rate * (b1 * (1 - 1 / t) - 1) * u,
Expand All @@ -241,7 +243,7 @@ def update_fn(updates, state, params=None, *, obj_fn, **kwargs):
updates = jax.tree_map(lambda new_y, y: new_y - y, new_y, params)

state = ScaleBySophiaState(
- count=t + 1, hessian_count=state.hessian_count, z=new_z, h=h_hat, hess_key=state.hess_key
+ count=t, hessian_count=state.hessian_count, z=new_z, h=h_hat, hess_key=state.hess_key
)
state = update_hessian(state, params, obj_fn=obj_fn, **kwargs)
return updates, state
Expand Down

0 comments on commit ff73972

Please sign in to comment.