Incompatible devices for jitted computation #609

Closed
versae opened this issue Jun 1, 2024 · 5 comments · Fixed by #622

Comments

@versae
Contributor

versae commented Jun 1, 2024

On a new TPUv4-64, I started training a Mistral model. The data tokenized properly, but right at the end of model loading I got this strange error I've never seen before. Any clue would be appreciated:

2024-06-01T19:39:55 - 0 - __main__ - train_lm.py:124 - INFO :: No training checkpoint found. Initializing model from HF checkpoint 'mimir-project/mimir-mistral-7b-base-scratch'
/home/ubuntu/venv310/lib/python3.10/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/ubuntu/levanter/src/levanter/main/train_lm.py", line 203, in <module>
    levanter.config.main(main)()
  File "/home/ubuntu/levanter/src/levanter/config.py", line 84, in wrapper_inner
    response = fn(cfg, *args, **kwargs)
  File "/home/ubuntu/levanter/src/levanter/main/train_lm.py", line 131, in main
    model = converter.load_pretrained(
  File "/home/ubuntu/levanter/src/levanter/compat/hf_checkpoints.py", line 587, in load_pretrained
    lev_model = load_from_state_dict(state_dict)
  File "/home/ubuntu/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 255, in __call__
    return self._call(False, *args, **kwargs)
  File "/home/ubuntu/venv310/lib/python3.10/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/ubuntu/venv310/lib/python3.10/site-packages/haliax/partitioning.py", line 333, in _call
    out, out_static = cached_pjitted_fun(dynamic_donated, dynamic_reserved, static)
ValueError: Received incompatible devices for jitted computation. Got argument dynamic_donated[1][0][0]['lm_head.weight'] of HFCheckpointConverter.load_pretrained.<locals>.load_from_state_dict with shape bfloat16[32768,4096] and device ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] on platform TPU and explicit output sharding with device ids [0, 8, 16, 24, 2, 10, 18, 26, 4, 12, 20, 28, 6, 14, 22, 30, 1, 9, 17, 25, 3, 11, 19, 27, 5, 13, 21, 29, 7, 15, 23, 31] on platform TPU
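
The key detail in the error is that the argument's sharding uses the devices in default enumeration order ([0, 1, 2, ...]) while the explicit output sharding uses a reordered device list ([0, 8, 16, 24, ...]). The following minimal, CPU-only sketch (illustrative only, not Levanter code; it fakes 8 host devices) reproduces that class of mismatch by committing an input to one mesh and jitting with an output sharding built from a mesh whose devices are ordered differently:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # fake 8 CPU devices

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devs = np.array(jax.devices())
mesh_a = Mesh(devs, ("data",))        # devices in default order
mesh_b = Mesh(devs[::-1], ("data",))  # same devices, different order

x = jax.device_put(np.zeros((8, 4)), NamedSharding(mesh_a, P("data")))
f = jax.jit(lambda a: a, out_shardings=NamedSharding(mesh_b, P("data")))

try:
    f(x)
except ValueError as e:
    print(e)  # "Received incompatible devices for jitted computation ..."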

The model config:

model:
  type: mistral

  num_layers: 32
  num_heads: 32
  hidden_dim: 4096
  intermediate_dim: 14336
  seq_len: 2048
  num_kv_heads: 8
  sliding_window: 2048

  activation_function: silu
  initializer_range: 0.02
  layer_norm_epsilon: 1e-05
  
  # upcast_attn: false
  use_flash_attention: true
  attn_backend: SPLASH
  # flash_attention_block_size: null
  # gradient_checkpointing: true
  # gradient_checkpointing_block_size: 5
  scan_layers: true
  # use_bias: false
  # rope_scaling: null
initialize_from_hf: "mimir-project/mimir-mistral-7b-base-scratch"
use_hf_model_config: false

And the trainer config:

trainer:
  mp: p=f32,c=bfloat16
  train_batch_size: 2048
  per_device_parallelism: 32
  per_device_eval_parallelism: 32
  num_train_steps: 10000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"

optimizer:
  lr_schedule: cosine
  learning_rate: 3e-5
  beta1: 0.9
  beta2: 0.95
  epsilon: 1e-8
  weight_decay: 0.1
  warmup: 0.02
  min_lr_ratio: 0.1

I tried both Splash and JAX flash attention with the same result. The same config, with half the values of per_device_parallelism and per_device_eval_parallelism, works fine on a TPUv4-32.
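
For context on the slice sizes (a back-of-the-envelope note, assuming per_device_parallelism means examples processed per device per microbatch; check the Levanter docs for the exact semantics): a TPU v4-64 slice exposes 32 JAX devices (each v4 chip runs as one megacore device), which matches the 32 device ids in the error above, while a v4-32 exposes 16.

devices_v4_64 = 64 // 2   # 32 JAX devices
devices_v4_32 = 32 // 2   # 16 JAX devices

train_batch_size = 2048
per_device_parallelism = 32

examples_per_microbatch = devices_v4_64 * per_device_parallelism  # 1024
grad_accum_steps = train_batch_size // examples_per_microbatch    # 2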

@dlwh
Member

dlwh commented Jun 1, 2024

Well, that's a remarkably unhelpful error message. I have a strong suspicion it's because of #588 and it's now using an incompatible mesh. Can you try reverting that commit and seeing if that fixes it? If it does, it's a pretty easy patch.
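
For what it's worth, the reordered device ids in the error ([0, 8, 16, 24, ...]) look like the topology-aware ordering that jax.experimental.mesh_utils.create_device_mesh produces on TPU slices, whereas the checkpoint tensors are committed to the devices in default enumeration order. A quick way to compare the two orderings (a sketch, not Levanter code; on CPU the two lists coincide):

import jax
import numpy as np
from jax.experimental import mesh_utils

default_order = np.array(jax.devices())
topology_order = mesh_utils.create_device_mesh((jax.device_count(),))

print([d.id for d in default_order])             # e.g. [0, 1, 2, 3, ...]
print([d.id for d in topology_order.flatten()])  # may be reordered on TPU slices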

@versae
Contributor Author

versae commented Jun 1, 2024

Indeed. I just checked out 2bb1252 and it trains nicely, but I had to reduce parallelism to 8 to make it fit on the v4-64, which is strange.

@rjpower
Collaborator

rjpower commented Jun 4, 2024

Ah, interesting. I hit this same error when trying to use llama2-lora. I found that if I disabled the "best-effort sharding" in hf_checkpoints I could get training to work; the branch in question is guarded by:

if jax.device_count() > 1:
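
To make the workaround concrete, here is a hypothetical sketch (not the actual hf_checkpoints code) of why skipping that branch can help: pre-sharding commits the loaded tensor to a sharding built from jax.devices() in default order, which is exactly the [0, 1, 2, ...] assignment the error complains about, while an uncommitted array can be placed freely by the jitted load according to the trainer's own mesh.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

weight = np.zeros((256, 64), dtype=np.float32)  # stand-in for a checkpoint tensor

if jax.device_count() > 1:                            # the guard referenced above
    mesh = Mesh(np.array(jax.devices()), ("embed",))  # default device order
    weight = jax.device_put(weight, NamedSharding(mesh, P("embed")))
# Skipping this branch leaves `weight` uncommitted, so a later jitted
# computation is free to shard it with its own (possibly reordered) mesh.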

@dlwh
Member

dlwh commented Jun 5, 2024

OK, I think this can be fixed with a long-standing thing I've been meaning to fix in Haliax (but that isn't quite compatible with some of the CPU-only testing I do).

@dlwh
Member

dlwh commented Jun 11, 2024

Can you try #622? It's a simpler fix than the Haliax thing.
