diff --git a/src/levanter/optim/schedulefree_sophia.py b/src/levanter/optim/schedulefree_sophia.py index 8b151e942..f371e690a 100644 --- a/src/levanter/optim/schedulefree_sophia.py +++ b/src/levanter/optim/schedulefree_sophia.py @@ -256,7 +256,7 @@ def _do_update(): # EMAs of hessian nu = update_moment(new_hess, state.h, b2, 1) return ScaleBySophiaState( - count=state.count, hessian_count=state.hessian_count + 1, mu=state.mu, h=nu, hess_key=next_key + count=state.count, hessian_count=state.hessian_count + 1, z=state.z, h=nu, hess_key=next_key ) def _dont_update():