Getting NaN gradients when computing the loss with JIT compilation #400
-
I was able to solve the issue. It seems the problem was in how I was initializing the state with Brax.
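The reply does not say what was wrong with the initial state. As a purely hypothetical illustration of how specific state values can produce NaN gradients, the sketch below differentiates a Euclidean norm at exactly zero. Think of an all-zero velocity in a freshly initialized state; that is an assumption, not something the thread confirms. `naive_norm` and `safe_norm` are illustrative names. `safe_norm` applies the "double where" workaround described in the JAX FAQ.

```python
import jax
import jax.numpy as jnp


def naive_norm(v):
    # d/dv sqrt(sum(v**2)) = v / ||v||, which is 0/0 at v = 0; JAX's
    # backward pass evaluates it as inf * 0 = NaN.
    return jnp.sqrt(jnp.sum(v ** 2))


def safe_norm(v):
    # "Double where": swap in a harmless value *before* sqrt so neither
    # the forward nor the backward pass ever evaluates sqrt at 0.
    sq = jnp.sum(v ** 2)
    nonzero = sq > 0.0
    safe_sq = jnp.where(nonzero, sq, 1.0)
    return jnp.where(nonzero, jnp.sqrt(safe_sq), 0.0)


zero_state = jnp.zeros(3)                  # hypothetical all-zero initial state
moving_state = jnp.array([3.0, 4.0, 0.0])  # hypothetical non-zero state

print(jax.grad(naive_norm)(zero_state))    # NaNs
print(jax.grad(safe_norm)(zero_state))     # zeros
print(jax.grad(safe_norm)(moving_state))   # v / ||v||
```

Away from zero, `safe_norm` returns the same value and gradient as `naive_norm`. It only changes behaviour at the single point where the naive gradient is undefined.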
-
Hello,
I am fairly new to JAX and Brax, and I am running into an issue that I don't quite understand. I am trying to get the sensitivity of the loss with respect to the model params by running the following minimal reproducible example.
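For readers arriving with the same question, here is a minimal sketch of the basic pattern. It takes `jax.grad` of a loss with respect to a parameter pytree, runs it once under `jax.jit` and once inside `jax.disable_jit()`, and compares the two results. The linear model, `init_params`, and `loss_fn` are hypothetical stand-ins, since the original reproducible example (a Brax/Equinox model) is not included here. For a numerically well-behaved loss, the two paths should agree to within floating-point tolerance.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_dim=3, out_dim=1):
    # Hypothetical stand-in for the model parameters: a linear layer as a dict pytree.
    w_key, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
    }


def loss_fn(params, x, y):
    # Mean squared error of the linear model's predictions.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


# jax.grad differentiates with respect to the first argument (the params
# pytree), so the returned gradients have the same structure as params.
grad_fn = jax.grad(loss_fn)
jit_grad_fn = jax.jit(grad_fn)

params = init_params(jax.random.PRNGKey(0))
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))

grads_jit = jit_grad_fn(params, x, y)

# Same computation with JIT disabled: the effect of
# config.update("jax_disable_jit", True), but scoped to this block.
with jax.disable_jit():
    grads_eager = grad_fn(params, x, y)

for name in ("w", "b"):
    print(name, "jit:", grads_jit[name].ravel(), "eager:", grads_eager[name].ravel())
```

Using the `jax.disable_jit()` context manager rather than the global config flag keeps the slow, un-jitted path confined to the comparison.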
When I run this with JAX's JIT compilation enabled, I get NaN for the gradients. When I disable JIT with config.update("jax_disable_jit", True), I get numeric (non-NaN) gradients, but the computation is very slow. Can anyone shed some light on why this happens, and on how I can compute gradients with JIT enabled? I am raising this issue with the JAX authors as well. I also brought it up with the authors of Equinox (https://github.com/patrick-kidger/equinox/issues/523), and they suggested I raise it here.
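One way to locate where a NaN first appears inside a jitted computation is JAX's `jax_debug_nans` flag. The sketch below uses the NaN-gradient example from the JAX FAQ (a `jnp.where` guarding `jnp.log`), not the poster's model. With the flag on, the jitted gradient raises `FloatingPointError` instead of silently returning NaN. The flag adds overhead, so the sketch switches it back off afterwards.

```python
import jax
import jax.numpy as jnp


def my_log(x):
    # Example from the JAX FAQ: the forward value at x = 0 is fine (0.0),
    # but the backward pass of jnp.log computes 0 / 0 = NaN for the
    # masked-out branch, so the gradient at 0 is NaN.
    return jnp.where(x > 0.0, jnp.log(x), 0.0)


grad_fn = jax.jit(jax.grad(my_log))

# With jax_debug_nans enabled, JAX checks outputs for NaNs and raises
# FloatingPointError. For a jitted function it re-runs the de-optimized
# (un-jitted) version to point at the operation that produced the NaN.
jax.config.update("jax_debug_nans", True)
try:
    grad_fn(0.0)
    raised = False
except FloatingPointError:
    raised = True
finally:
    # Restore the default so the flag does not slow down later code.
    jax.config.update("jax_debug_nans", False)
print("FloatingPointError raised:", raised)

# Without the flag, the NaN silently propagates into the gradient.
plain_grad = grad_fn(0.0)
print("gradient at 0:", plain_grad)
```

The traceback attached to the `FloatingPointError` usually names the primitive (here the division in the backward pass of `log`), which is often enough to trace a NaN back to a particular input or state value.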