
Main branch compilation on nvcc 12.6 #1453

Open
roded2 opened this issue Jan 21, 2025 · 2 comments
roded2 commented Jan 21, 2025

A recent commit (7a80279#diff-c5844b57c2c63ac3f0275f6f19f0f0fae920ef927cf4cba54ce3d52defae0aba) changed hopper/setup.py to restrict compilation to nvcc 12.3:

if bare_metal_version != Version("12.3"): # nvcc 12.3 gives the best perf currently

According to the comment in the code, the reason for this restriction is that 12.3 gives better perf.
I tried to run hopper/setup.py with the nvcc installed on my machine (nvcc 12.6) by changing the if statement above to evaluate to false. The compilation fails deterministically with a segfault in nvcc:

[13/42] /home/user/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/user/flash-attention/hopper/build/temp.linux-x86_64-cpython-312/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.o.d -I/home/user/flash-attention/hopper -I/home/user/flash-attention/csrc/cutlass/include -I/home/user/lib/python3.12/site-packages/torch/include -I/home/user/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/user/lib/python3.12/site-packages/torch/include/TH -I/home/user/lib/python3.12/site-packages/torch/include/THC -I/home/user/include -I/home/user/include/python3.12 -c -c /home/user/flash-attention/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu -o /home/user/flash-attention/hopper/build/temp.linux-x86_64-cpython-312/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' --threads 4 -O3 -std=c++17 --ftemplate-backtrace-limit=0 --use_fast_math --resource-usage -lineinfo -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -gencode arch=compute_90a,code=sm_90a -DFLASHATTENTION_DISABLE_SM8x -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flash_attn_3_cuda -D_GLIBCXX_USE_CXX11_ABI=1 -ccbin /home/user/bin/x86_64-conda-linux-gnu-cc
FAILED: /home/user/flash-attention/hopper/build/temp.linux-x86_64-cpython-312/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.o
/home/user/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/user/flash-attention/hopper/build/temp.linux-x86_64-cpython-312/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.o.d -I/home/user/flash-attention/hopper -I/home/user/flash-attention/csrc/cutlass/include -I/home/user/lib/python3.12/site-packages/torch/include -I/home/user/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/user/lib/python3.12/site-packages/torch/include/TH -I/home/user/lib/python3.12/site-packages/torch/include/THC -I/home/user/include -I/home/user/include/python3.12 -c -c /home/user/flash-attention/hopper/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu -o /home/user/flash-attention/hopper/build/temp.linux-x86_64-cpython-312/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' --threads 4 -O3 -std=c++17 --ftemplate-backtrace-limit=0 --use_fast_math --resource-usage -lineinfo -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED -DCUTLASS_DEBUG_TRACE_LEVEL=0 -DNDEBUG -gencode arch=compute_90a,code=sm_90a -DFLASHATTENTION_DISABLE_SM8x -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=flash_attn_3_cuda -D_GLIBCXX_USE_CXX11_ABI=1 -ccbin /home/user/bin/x86_64-conda-linux-gnu-cc
nvcc warning : incompatible redefinition for option 'compiler-bindir', the last value of this option was used
Segmentation fault (core dumped)
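For reference, the version gate being bypassed behaves like this. This is a minimal sketch: the real hopper/setup.py compares packaging.version.Version objects, which are emulated here with plain integer tuples.

```python
def parse_release(version_str):
    """Parse an nvcc release string like "12.6" into a comparable tuple."""
    return tuple(int(part) for part in version_str.split("."))

def nvcc_release_allowed(bare_metal_version, pinned="12.3"):
    """Mirror the strict equality check in hopper/setup.py:
    any release other than the pinned one is rejected."""
    return parse_release(bare_metal_version) == parse_release(pinned)

print(nvcc_release_allowed("12.3"))  # True
print(nvcc_release_allowed("12.6"))  # False: rejected even though it is newer
```

Because the check is strict equality rather than a minimum version, a newer toolkit like 12.6 fails the gate just as an older one would.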

Does the compilation have to happen with nvcc 12.3 for correctness as well, and not just for perf? Is this a known issue?
The command I used to compile is MAX_JOBS=64 FLASH_ATTENTION_FORCE_BUILD=TRUE FLASH_ATTENTION_DISABLE_SM80=TRUE python setup.py install, but the same issue happens with plain python setup.py install (run on Ubuntu 22).
Thanks

@roded2 roded2 changed the title Main branch compilation on mvcc 12.6 Main branch compilation on nvcc 12.6 Jan 21, 2025
@ankutalev
Same issue; I can reproduce it just by building the CUDA sources directly with nvcc.


tridao commented Jan 26, 2025

I haven't tried nvcc 12.6 with all of the features.
You can try disabling more features. With these flags it should compile ok with 12.6 (with lower perf than compiling with 12.3).

FLASH_ATTENTION_DISABLE_BACKWARD=FALSE
FLASH_ATTENTION_DISABLE_SPLIT=TRUE
FLASH_ATTENTION_DISABLE_LOCAL=TRUE
FLASH_ATTENTION_DISABLE_PAGEDKV=TRUE
FLASH_ATTENTION_DISABLE_FP16=TRUE
FLASH_ATTENTION_DISABLE_FP8=TRUE
FLASH_ATTENTION_DISABLE_APPENDKV=TRUE
FLASH_ATTENTION_DISABLE_VARLEN=TRUE
FLASH_ATTENTION_DISABLE_CLUSTER=FALSE
FLASH_ATTENTION_DISABLE_PACKGQA=TRUE
FLASH_ATTENTION_DISABLE_SOFTCAP=TRUE
FLASH_ATTENTION_DISABLE_HDIM64=TRUE
FLASH_ATTENTION_DISABLE_HDIM96=TRUE
FLASH_ATTENTION_DISABLE_HDIM128=FALSE
FLASH_ATTENTION_DISABLE_HDIM192=TRUE
FLASH_ATTENTION_DISABLE_HDIM256=TRUE
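These flags are plain environment variables read by setup.py at build time (TRUE disables a feature, FALSE keeps it). As a sketch, a hypothetical driver that applies the whole set before invoking the build could look like this:

```python
import os
import subprocess

# The flag set suggested above, verbatim.
FLAGS = {
    "FLASH_ATTENTION_DISABLE_BACKWARD": "FALSE",
    "FLASH_ATTENTION_DISABLE_SPLIT": "TRUE",
    "FLASH_ATTENTION_DISABLE_LOCAL": "TRUE",
    "FLASH_ATTENTION_DISABLE_PAGEDKV": "TRUE",
    "FLASH_ATTENTION_DISABLE_FP16": "TRUE",
    "FLASH_ATTENTION_DISABLE_FP8": "TRUE",
    "FLASH_ATTENTION_DISABLE_APPENDKV": "TRUE",
    "FLASH_ATTENTION_DISABLE_VARLEN": "TRUE",
    "FLASH_ATTENTION_DISABLE_CLUSTER": "FALSE",
    "FLASH_ATTENTION_DISABLE_PACKGQA": "TRUE",
    "FLASH_ATTENTION_DISABLE_SOFTCAP": "TRUE",
    "FLASH_ATTENTION_DISABLE_HDIM64": "TRUE",
    "FLASH_ATTENTION_DISABLE_HDIM96": "TRUE",
    "FLASH_ATTENTION_DISABLE_HDIM128": "FALSE",
    "FLASH_ATTENTION_DISABLE_HDIM192": "TRUE",
    "FLASH_ATTENTION_DISABLE_HDIM256": "TRUE",
}

def build_env():
    """Merge the flags into a copy of the current environment."""
    env = dict(os.environ)
    env.update(FLAGS)
    return env

if __name__ == "__main__":
    disabled = sum(v == "TRUE" for v in FLAGS.values())
    print(f"{disabled} features disabled")
    # Uncomment to actually build (run from the hopper/ directory):
    # subprocess.run(["python", "setup.py", "install"], env=build_env(), check=True)
```

With this set, 13 of the 16 listed features are disabled; only the backward pass, clusters, and hdim128 remain, which is why the resulting build has lower perf and coverage than a full 12.3 build.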
