[BUG] Device error when running on other cuda device than cuda:0 #215

Open
cornzz opened this issue Aug 28, 2024 · 1 comment · May be fixed by #216
Labels
bug Something isn't working

Comments

cornzz (Contributor) commented Aug 28, 2024

Python -VV

Python 3.11.9 (main, Apr 19 2024, 16:48:06) [GCC 11.2.0]

Pip Freeze

conda env export
name: test
channels:
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - asttokens=2.0.5=pyhd3eb1b0_0
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.2=h06a4308_0
  - comm=0.2.1=py311h06a4308_0
  - debugpy=1.6.7=py311h6a678d5_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - executing=0.8.3=pyhd3eb1b0_0
  - ipykernel=6.28.0=py311h06a4308_0
  - ipython=8.25.0=py311h06a4308_0
  - jedi=0.19.1=py311h06a4308_0
  - jupyter_client=8.6.0=py311h06a4308_0
  - jupyter_core=5.7.2=py311h06a4308_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.6=py311h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=py311h06a4308_0
  - openssl=3.0.14=h5eee18b_0
  - packaging=24.1=py311h06a4308_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pip=24.2=py311h06a4308_0
  - platformdirs=3.10.0=py311h06a4308_0
  - prompt-toolkit=3.0.43=py311h06a4308_0
  - prompt_toolkit=3.0.43=hd3eb1b0_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pygments=2.15.1=py311h06a4308_1
  - python=3.11.9=h955ad1f_0
  - python-dateutil=2.9.0post0=py311h06a4308_2
  - pyzmq=25.1.2=py311h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=72.1.0=py311h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py311h5eee18b_0
  - traitlets=5.14.3=py311h06a4308_0
  - typing_extensions=4.11.0=py311h06a4308_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - wheel=0.43.0=py311h06a4308_0
  - xz=5.4.6=h5eee18b_1
  - zeromq=4.3.5=h6a678d5_0
  - zlib=1.2.13=h5eee18b_1
  - pip:
      - accelerate==0.33.0
      - aiohappyeyeballs==2.4.0
      - aiohttp==3.10.5
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - anyio==4.4.0
      - attrs==24.2.0
      - certifi==2024.7.4
      - charset-normalizer==3.3.2
      - click==8.1.7
      - datasets==2.21.0
      - dill==0.3.8
      - distro==1.9.0
      - docstring-parser==0.16
      - evaluate==0.4.2
      - filelock==3.15.4
      - fire==0.6.0
      - frozenlist==1.4.1
      - fsspec==2024.6.1
      - fuzzywuzzy==0.18.0
      - h11==0.14.0
      - httpcore==1.0.5
      - httpx==0.27.1
      - huggingface-hub==0.24.5
      - idna==3.7
      - jieba==0.42.1
      - jinja2==3.1.4
      - jiter==0.5.0
      - joblib==1.4.2
      - jsonschema==4.23.0
      - jsonschema-specifications==2023.12.1
      - llmlingua==0.2.2
      - markupsafe==2.1.5
      - mistral-common==1.3.4
      - mistral-inference==1.3.1
      - mpmath==1.3.0
      - multidict==6.0.5
      - multiprocess==0.70.16
      - networkx==3.3
      - nltk==3.8.1
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu12==9.1.0.70
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.20.5
      - nvidia-nvjitlink-cu12==12.6.20
      - nvidia-nvtx-cu12==12.1.105
      - openai==1.42.0
      - pandas==2.2.2
      - psutil==6.0.0
      - pyarrow==17.0.0
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pytz==2024.1
      - pyyaml==6.0.2
      - referencing==0.35.1
      - regex==2024.7.24
      - requests==2.32.3
      - rouge==1.0.1
      - rpds-py==0.20.0
      - safetensors==0.4.4
      - sentencepiece==0.2.0
      - simple-parsing==0.1.5
      - sniffio==1.3.1
      - sympy==1.13.2
      - termcolor==2.4.0
      - tiktoken==0.7.0
      - tokenizers==0.19.1
      - torch==2.4.0
      - tqdm==4.66.5
      - transformers==4.44.0
      - triton==3.0.0
      - typing-extensions==4.12.2
      - tzdata==2024.1
      - urllib3==2.2.2
      - xformers==0.0.27.post2
      - xxhash==3.5.0
      - yarl==1.9.4
prefix: /home/test

Reproduction Steps

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
tokenizer = MistralTokenizer.from_file("./models/mistral-7B-v0.1/tokenizer.model").instruct_tokenizer.tokenizer

prompt = "What is the capital of germany? Answer:"
tokens = tokenizer.encode(prompt, bos=True, eos=False)
out_tokens, logprobs = generate([tokens], model, max_tokens=50, temperature=0)
result = tokenizer.decode(out_tokens[0])

Expected Behavior

I am getting the following error when trying to run the above code:

ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0

This seems related to facebookresearch/xformers#1064; I couldn't figure out why this happens yet...
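A possible workaround sketch (my assumption, not verified against the library internals): if the offending attention bias is allocated on the current default CUDA device, pinning the default device to the target GPU before loading the model and generating might keep it on the same device as the weights:

import torch

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# Assumption: the bias ends up on the default CUDA device (cuda:0 unless changed),
# so making cuda:7 the default should align it with the model weights.
torch.cuda.set_device(7)

model = Transformer.from_folder("./models/mistral-7B-v0.1", device="cuda:7")
tokenizer = MistralTokenizer.from_file("./models/mistral-7B-v0.1/tokenizer.model").instruct_tokenizer.tokenizer

tokens = tokenizer.encode("What is the capital of germany? Answer:", bos=True, eos=False)
out_tokens, logprobs = generate([tokens], model, max_tokens=50, temperature=0)
print(tokenizer.decode(out_tokens[0]))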

Additional Context

Stack trace
File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/generate.py", line 82, in generate                       
    prelogits = model.forward(                                                                                                                           
                ^^^^^^^^^^^^^^                                                                                                                           
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 276, in forward                    
    h = self.forward_partial(input_ids, seqlens, cache=cache)                                                                                            
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                            
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 258, in forward_partial            
    h = layer(h, freqs_cis, cache_view)                                                                                                                  
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl              
    return self._call_impl(*args, **kwargs)                                                                                                              
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                              
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl                      
    return forward_call(*args, **kwargs)                                                                                                                 
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                 
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 156, in forward                    
    r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)                                                                                 
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                 
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/mistral_inference/transformer.py", line 100, in forward                    
    output = memory_efficient_attention(xq, key, val, None if cache is None else cache.mask)                                                             
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                             
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 276, in memory_efficient_attention    
    return _memory_efficient_attention(                                                                                                                  
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                                  
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 395, in _memory_efficient_attention   
    return _memory_efficient_attention_forward(                                                                                                          
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^                                                                                                          
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py", line 411, in _memory_efficient_attention_for
ward                                                                                                                                                     
    inp.validate_inputs()                                                                                                                                
  File "/home/test/miniconda3/envs/test/lib/python3.11/site-packages/xformers/ops/fmha/common.py", line 145, in validate_inputs                 
    raise ValueError(
ValueError: Attention bias and Query/Key/Value should be on the same device
  query.device: cuda:7
  attn_bias   : cuda:0
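
For reference, a standalone snippet (hypothetical, not taken from mistral_inference) that triggers the same xformers check by passing a bias tensor on a different device than the query:

import torch
from xformers.ops import memory_efficient_attention

# q/k/v on cuda:1, bias on cuda:0 -> validate_inputs() raises
# "Attention bias and Query/Key/Value should be on the same device"
q = torch.randn(1, 16, 8, 64, device="cuda:1", dtype=torch.float16)
k = torch.randn(1, 16, 8, 64, device="cuda:1", dtype=torch.float16)
v = torch.randn(1, 16, 8, 64, device="cuda:1", dtype=torch.float16)
bias = torch.zeros(1, 8, 16, 16, device="cuda:0", dtype=torch.float16)

memory_efficient_attention(q, k, v, attn_bias=bias)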

Suggested Solutions

No response

cornzz added the bug label on Aug 28, 2024
cornzz (Contributor, Author) commented Aug 28, 2024

Seems I found the issue, PR: #216
