diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py index 196cfb527..352f2e29f 100644 --- a/src/levanter/optim/kron.py +++ b/src/levanter/optim/kron.py @@ -497,7 +497,7 @@ def _balance_Q(Q: List[jax.Array]): updates = grads_structure.unflatten(precond_gs) Qs = grads_structure.unflatten(Qs) - precond_gs = updates_struct.unflatten(precond_gs) + updates = updates_struct.unflatten(updates) # dtypes and new state mu = otu.tree_cast(mu, mu_dtype)