From c47c4c52e9cdb67aea19df84a033aa6af2cdbef5 Mon Sep 17 00:00:00 2001 From: Evan Walters Date: Wed, 18 Dec 2024 21:03:05 -0700 Subject: [PATCH] small fix --- src/levanter/optim/kron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/levanter/optim/kron.py b/src/levanter/optim/kron.py index 196cfb527..352f2e29f 100644 --- a/src/levanter/optim/kron.py +++ b/src/levanter/optim/kron.py @@ -497,7 +497,7 @@ def _balance_Q(Q: List[jax.Array]): updates = grads_structure.unflatten(precond_gs) Qs = grads_structure.unflatten(Qs) - precond_gs = updates_struct.unflatten(precond_gs) + updates = updates_struct.unflatten(updates) # dtypes and new state mu = otu.tree_cast(mu, mu_dtype)