diff --git a/src/levanter/optim/util.py b/src/levanter/optim/util.py index f73638740..b1456ec07 100644 --- a/src/levanter/optim/util.py +++ b/src/levanter/optim/util.py @@ -1,5 +1,6 @@ import equinox as eqx import jax +import jax.numpy as jnp from levanter.utils.jax_utils import is_inexact_arrayish