I was training a llama model on GPU with a custom embedding. It worked fine with 12 layers, dim 1024, and sequence length 256, but the loss became NaN after the first step when setting num_layers to more than 17. I debugged the gradients and found that their magnitude grew by roughly 100x after each layer, until they hit float32_max at around the 18th layer and became inf, leading to the NaN loss.
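For reference, a minimal self-contained sketch of the kind of per-layer gradient-norm check used to spot this growth (a toy stack of dense layers, not the actual llama model or training script):

```python
# Toy sketch only: a small stack of dense layers standing in for the real model,
# used to show how per-layer gradient norms can be printed and compared.
import jax
import jax.numpy as jnp

def init_params(key, num_layers, dim):
    keys = jax.random.split(key, num_layers)
    return [jax.random.normal(k, (dim, dim)) / jnp.sqrt(dim) for k in keys]

def forward(params, x):
    for w in params:
        x = jnp.tanh(x @ w)
    return x

def loss_fn(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

key = jax.random.PRNGKey(0)
params = init_params(key, num_layers=12, dim=64)
x = jax.random.normal(key, (8, 64))
y = jnp.zeros((8, 64))

grads = jax.grad(loss_fn)(params, x, y)
for i, g in enumerate(grads):
    # Healthy training keeps these norms roughly flat across layers;
    # in the buggy run they grew by ~100x per layer until they overflowed.
    print(f"layer {i}: grad norm = {float(jnp.linalg.norm(g)):.3e}")
```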
The gradient explosion seemed to be coming from
local_exps = jnp.exp(attn_weights - local_max)
in attentions.py.
Changing
DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
to
DEFAULT_MASK_VALUE = -jnp.inf
fixed the issue, and the gradients' magnitude stopped increasing after each layer.
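To illustrate where the two mask values can diverge, here is a standalone sketch of the exp(attn_weights - local_max) pattern (not the actual attentions.py code; applying the mask via jnp.where is an assumption made for illustration). The outputs only differ for a block whose positions are all masked: with the large finite default, local_max equals the mask value itself, so every masked position evaluates to exp(0) = 1 and keeps a live gradient path through entries that should contribute nothing, which is at least consistent with the per-layer growth above. With -jnp.inf, masked positions are exactly zero whenever the row has at least one valid key; a fully -inf row yields NaNs in this naive form and would have to be skipped or special-cased by the kernel.

```python
# Standalone sketch of the local_exps pattern, NOT the attentions.py kernel;
# masking via jnp.where is an assumption made for illustration.
import jax.numpy as jnp

FINITE_MASK = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)

def local_exps(attn_weights, mask, mask_value):
    w = jnp.where(mask, attn_weights, mask_value)
    local_max = jnp.max(w, axis=-1, keepdims=True)
    return jnp.exp(w - local_max)

weights = jnp.array([[2.0, 1.0, 0.5, -0.5]], dtype=jnp.float32)
partial = jnp.array([[True, True, False, False]])    # row with some valid keys
full    = jnp.array([[False, False, False, False]])  # fully masked block

print(local_exps(weights, partial, FINITE_MASK))  # [[1. 0.3679 0. 0.]] -- fine
print(local_exps(weights, partial, -jnp.inf))     # [[1. 0.3679 0. 0.]] -- identical
print(local_exps(weights, full, FINITE_MASK))     # [[1. 1. 1. 1.]]     -- spurious non-zeros
print(local_exps(weights, full, -jnp.inf))        # [[nan nan nan nan]] -- must be skipped upstream
```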
Presumably the issue wasn't noticed during TPU training as that uses a separate codepath.