Incompatible devices for jitted computation #609
Comments
Well, that's a remarkably unhelpful error message. I have a strong suspicion it's because of #588 and it's now using an incompatible mesh. Can you try reverting that commit and seeing if that fixes it? If it does, it's a pretty easy patch.
Indeed. I just checked out 2bb1252 and it trains nicely, but I had to reduce parallelism to 8 to make it fit into the v4-64, which is strange.
Ah, interesting. I hit this same error when trying to use llama2-lora. I found that if I disabled the "best-effort sharding" in hf_checkpoints, I could get training to work: levanter/src/levanter/compat/hf_checkpoints.py Line 1072 in 9adba01
OK, I think this can be fixed with a long-standing thing I've been meaning to fix in Haliax (but it isn't quite compatible with some of the CPU-only testing I do).
Can you try #622, which is a simpler fix than the Haliax thing?
On a new TPUv4-64, I started training a Mistral model. The data tokenized properly, but right at the end of model loading I got this strange error that I've never seen before. Any clue will be appreciated:
The model config:
And the trainer config:
I tried with both Splash and JAX FA attention with the same result. The same config with half the value of per_device_parallelism and per_device_eval_parallelism works fine on a TPUv4-32.
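For anyone debugging a similar "incompatible devices for jitted computation" failure, one generic check is to compare the device sets of the arrays being passed into a jitted function. The sketch below is plain JAX, not Levanter or Haliax code. The helper names `device_ids` and `same_devices` are made up for illustration, and the single-axis mesh named "data" is an assumption rather than the mesh used in this issue.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def device_ids(x):
    """Return the sorted ids of the devices holding shards of x (hypothetical helper)."""
    return sorted(d.id for d in x.sharding.device_set)


def same_devices(*arrays):
    """True if every array is placed on exactly the same set of devices (hypothetical helper)."""
    sets = [frozenset(device_ids(a)) for a in arrays]
    return all(s == sets[0] for s in sets)


# Assumed setup: one mesh over all visible devices with a single "data" axis.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Keep the batch a multiple of the device count so it splits evenly.
batch = 4 * devices.size

# Activations sharded along the batch axis; weights replicated on the same mesh.
x = jax.device_put(jnp.ones((batch, 8)), NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((8, 2)), NamedSharding(mesh, P(None, None)))

print("x devices:", device_ids(x))
print("w devices:", device_ids(w))

# Both inputs live on the same devices, so the jitted matmul can run.
assert same_devices(x, w)
y = jax.jit(lambda a, b: a @ b)(x, w)
```

If `same_devices` returns False for the model parameters and the batch, the arrays were probably placed using different meshes or device subsets, which would be consistent with the incompatible-mesh suspicion raised in the comments.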