[Bug]: fused_moe_kernel compile bug #6103

Closed
jeejeelee opened this issue Jul 3, 2024 · 13 comments · Fixed by #6140
Labels
bug Something isn't working


@jeejeelee
Collaborator

jeejeelee commented Jul 3, 2024

Your current environment

Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-107-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A800 80GB PCIe
GPU 1: NVIDIA A800 80GB PCIe
GPU 2: NVIDIA A800 80GB PCIe
GPU 3: NVIDIA A800 80GB PCIe
GPU 4: NVIDIA A800 80GB PCIe
GPU 5: NVIDIA A800 80GB PCIe
GPU 6: NVIDIA A800 80GB PCIe
GPU 7: NVIDIA A800 80GB PCIe

Nvidia driver version: 550.54.15
cuDNN version: Probably one of the following:
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.9.6
/usr/local/cuda-12.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.9.6
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      46 bits physical, 57 bits virtual
CPU(s):                             112
On-line CPU(s) list:                0-111
Thread(s) per core:                 2
Core(s) per socket:                 28
Socket(s):                          2
NUMA node(s):                       2
Vendor ID:                          GenuineIntel
CPU family:                         6
Model:                              106
Model name:                         Intel(R) Xeon(R) Gold 6330 CPU @ 2.00GHz
Stepping:                           6
CPU MHz:                            2000.000
CPU max MHz:                        3100.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4000.00
Virtualization:                     VT-x
L1d cache:                          2.6 MiB
L1i cache:                          1.8 MiB
L2 cache:                           70 MiB
L3 cache:                           84 MiB
NUMA node0 CPU(s):                  0-27,56-83
NUMA node1 CPU(s):                  28-55,84-111
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI Syscall hardening, KVM SW loop
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect wbnoinvd dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid fsrm md_clear pconfig flush_l1d arch_capabilities

Versions of relevant libraries:
[pip3] mypy==1.9.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] torchvision==0.18.0
[pip3] transformers==4.42.1
[pip3] triton==2.3.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
[conda] torch                     2.3.0                    pypi_0    pypi
[conda] torchvision               0.18.0                   pypi_0    pypi
[conda] transformers              4.42.1                   pypi_0    pypi
[conda] triton                    2.3.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.0.post1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV8	PXB	PXB	PXB	PXB	PXB	PXB	0-27,56-83	0		N/A
GPU1	NV8	 X 	PXB	PXB	PXB	PXB	PXB	PXB	0-27,56-83	0		N/A
GPU2	PXB	PXB	 X 	PXB	PXB	PXB	PXB	NV8	0-27,56-83	0		N/A
GPU3	PXB	PXB	PXB	 X 	NV8	PXB	PXB	PXB	0-27,56-83	0		N/A
GPU4	PXB	PXB	PXB	NV8	 X 	PXB	PXB	PXB	0-27,56-83	0		N/A
GPU5	PXB	PXB	PXB	PXB	PXB	 X 	NV8	PXB	0-27,56-83	0		N/A
GPU6	PXB	PXB	PXB	PXB	PXB	NV8	 X 	PXB	0-27,56-83	0		N/A
GPU7	PXB	PXB	NV8	PXB	PXB	PXB	PXB	 X 	0-27,56-83	0		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

🐛 Describe the bug

Description

#5036 raises the following error when using TP>1. I also hit similar errors when testing the Qwen1.5-MoE model, because it uses the fused_moe_kernel Triton kernel.

[rank0]:   File "/root/anaconda3/envs/py310_vllm/lib/python3.10/site-packages/triton/compiler/compiler.py", line 239, in __init__
[rank0]:     self.asm = {
[rank0]:   File "/root/anaconda3/envs/py310_vllm/lib/python3.10/site-packages/triton/compiler/compiler.py", line 240, in <dictcomp>
[rank0]:     file.suffix[1:]: file.read_bytes() if file.suffix[1:] == driver.binary_ext else file.read_text()
[rank0]:   File "/root/anaconda3/envs/py310_vllm/lib/python3.10/pathlib.py", line 1134, in read_text
[rank0]:     with self.open(mode='r', encoding=encoding, errors=errors) as f:
[rank0]:   File "/root/anaconda3/envs/py310_vllm/lib/python3.10/pathlib.py", line 1119, in open
[rank0]:     return self._accessor.open(self, mode, buffering, encoding, errors,
[rank0]: FileNotFoundError: [Errno 2] No such file or directory: '/root/.triton/cache/76c543c0d8742e904f283b49e0ee704b/_sgmv_expand_kernel.json.tmp.pid_450752_42384ec9-f78e-40b0-b440-c41643b19290'

Possible solutions

I believe this is a Triton bug. After debugging the Triton code and investigating the related issue, I found that the cause is well explained here. Until Triton officially fixes this, we may need to work around it in a similar manner. Another approach is to set distributed_executor_backend to ray when using Triton kernels.

Reproducible Code

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    logprobs=1,
    prompt_logprobs=1,
    max_tokens=128,
)

# Create an LLM.
model_path = "Qwen1.5-MoE-A2.7B-Chat"  # Other models that use the MoE Triton kernel should also reproduce the bug.

llm = LLM(
    model=model_path,
    # enable_lora=True,
    trust_remote_code=True,
    gpu_memory_utilization=0.4,
    tensor_parallel_size=8,
    enforce_eager=False,
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@jeejeelee jeejeelee added the bug Something isn't working label Jul 3, 2024
@jeejeelee
Collaborator Author

cc @njhill

@jeejeelee jeejeelee changed the title from "[Bug]: fused_moe_kernel complie bug" to "[Bug]: fused_moe_kernel compile bug" Jul 3, 2024
@tdoublep
Member

tdoublep commented Jul 4, 2024

Yeah, we've fixed this issue on our fork (as you found here). Let me create a PR to contribute the fix upstream.

@randxie

randxie commented Jul 4, 2024

I was able to work around this bug by building the latest Triton code from source.

@jeejeelee
Collaborator Author

It seems that Triton main has updated the related code, but I don't have time to dig into that update. For now, I believe #6140 can serve as a temporary fix for this issue. Once vLLM updates its Triton version, we can revisit it.

@tdoublep
Member

tdoublep commented Jul 5, 2024

@randxie Interesting. I actually tried to test these changes that were merged into Triton main in our fork, but it didn't help. I don't really see much else that has changed in the meantime (at least in python/triton/runtime/cache.py) so I'm wondering how it could be fixed upstream.

Agree that getting the fix from Triton is the long-term solution though.

@randxie

randxie commented Jul 5, 2024

@tdoublep I believe more changes were involved. If you check the log above, Triton tried to load files matching the pattern *.tmp.pid_*, but when you look inside the container, no such file exists. I haven't spent time figuring out the exact change; I just want to provide a data point that the latest Triton has fixed this issue.

@LSC527

LSC527 commented Jul 8, 2024

When running DeepSeek-V2, removing the entire Triton cache dir before every vLLM run can serve as a temporary workaround.
This workaround may not always work, though.
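A minimal sketch of that workaround in Python, meant to run once before vLLM starts its workers. It assumes Triton reads its cache location from the `TRITON_CACHE_DIR` environment variable and otherwise falls back to `~/.triton/cache` (that default matches the `/root/.triton/cache/...` paths in the tracebacks in this thread); `triton_cache_dir` and `clear_triton_cache` are hypothetical helper names, not part of vLLM or Triton.

```python
import os
import shutil
from pathlib import Path


def triton_cache_dir() -> Path:
    """Return the Triton cache dir (assumed: $TRITON_CACHE_DIR, else ~/.triton/cache)."""
    override = os.environ.get("TRITON_CACHE_DIR", "").strip()
    if override:
        return Path(override)
    return Path.home() / ".triton" / "cache"


def clear_triton_cache() -> bool:
    """Delete the whole Triton cache dir.

    Returns True if a directory was removed, False if there was nothing to remove.
    Call it before vLLM spawns its workers; deleting the cache while workers are
    compiling kernels would only introduce another race.
    """
    cache = triton_cache_dir()
    if not cache.is_dir():
        return False
    shutil.rmtree(cache)
    return True
```

As the comment says, this is not a guaranteed fix: after the cache is cleared, every worker still compiles into one shared directory, so the list-then-read race explained in triton-lang/triton#4295 can still occur.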

ThomasRaoux pushed a commit to triton-lang/triton that referenced this issue Jul 10, 2024
…4295)

# Summary
There have been multiple issues about the `FileNotFoundError` raised during
compilation when `CompiledKernel` tries to read the listed ASM files
(#2688, #4002, vllm-project/vllm#6103, etc.), and there have been some
attempts to address it, such as #3544.
This PR attempts to explain the root cause and suggest a fix.

# Why
When a kernel is being compiled, triton first writes IRs to triton cache
dir
([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L289)).
Inside of the write operation, the process first writes it to a temp
file unique to the current process (plus a uuid to distinguish between
multiple processes with same PID on different hosts sharing the same
underlying FS)
([ref](https://github.com/triton-lang/triton/blob/c14b033cd979d5c39e5fdb3847c022fa5d71a0c1/python/triton/runtime/cache.py#L124-L130))
and then atomically `os.replace` it to the final file name. Afterwards
the `CompiledKernel` lists all the IRs and reads them
([ref](https://github.com/triton-lang/triton/blob/78091647fccb6825ed9956ff7c0300859856d261/python/triton/compiler/compiler.py#L362-L367)).
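The write protocol described here can be sketched with the standard library as follows. This is an illustrative reimplementation, not Triton's `runtime/cache.py`: `atomic_write` is a hypothetical name, and the temp-file naming only imitates the `<name>.tmp.pid_<pid>_<suffix>` paths visible in the tracebacks in this thread.

```python
import os
import tempfile
import uuid
from pathlib import Path


def atomic_write(cache_dir: Path, filename: str, data: bytes) -> Path:
    """Publish cache_dir/filename so readers never observe a partially written file."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    final_path = cache_dir / filename
    # Unique per process (PID) and per host (random suffix), so writers sharing
    # one filesystem never clobber each other's temp files.
    tmp_path = cache_dir / f"{filename}.tmp.pid_{os.getpid()}_{uuid.uuid4().hex}"
    tmp_path.write_bytes(data)
    # os.replace renames atomically when source and target are on the same
    # filesystem: readers see either no file or the complete new one.
    os.replace(tmp_path, final_path)
    return final_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        atomic_write(Path(d), "kernel.json", b"{}")
        # Only the final file is left; the temp file was renamed onto it.
        print(sorted(p.name for p in Path(d).iterdir()))  # ['kernel.json']
```

Note that the temp file briefly lives in the same directory that `CompiledKernel` later lists.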

In a multiprocess setup, however, this can result in a race condition.
Consider a single host running 2 processes:
![Triton RC (1)](https://github.com/triton-lang/triton/assets/43726198/ffc20e0c-0404-4e7a-bd6c-022e710e97b9)

At the time `pid 1` lists the ASM files, the dir may contain temp files
generated by another process, `pid 2`. By the time `pid 1` proceeds to read
bytes from the listed files, however, `pid 2` may have already
`os.replace`d its temp files, so `pid 1` hits a `FileNotFoundError` when
trying to read a temp file generated by `pid 2`. IBM/vllm#35 (comment)
reaches the same conclusion about the root cause.

# How
There are multiple potential solutions, also mentioned in
IBM/vllm#35 (comment):
- have each process write to a private temp dir instead, so `glob` never
picks up in-flight temp files
- or, exclude `tmp.pid_*` from `glob`

This PR goes with the first approach to avoid adding an assumption about
the tmp file pattern (which is only used in `runtime/cache.py`) to
`compiler/compiler.py`, but is open to any suggestions. Thanks!
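Both options can be sketched with the standard library. This is a hedged illustration, not the code merged in triton-lang/triton#4295: `list_artifacts` shows the second option (filter the `tmp.pid_` pattern out of the listing) and `write_via_private_dir` shows the first (stage temp files in a per-process subdirectory). Both function names are hypothetical.

```python
import os
import shutil
import tempfile
from pathlib import Path
from typing import List


def list_artifacts(cache_dir: Path, stem: str) -> List[Path]:
    """Option 2: list a kernel's artifacts, skipping other processes' in-flight temp files."""
    return sorted(p for p in cache_dir.glob(f"{stem}.*") if ".tmp.pid_" not in p.name)


def write_via_private_dir(cache_dir: Path, filename: str, data: bytes) -> Path:
    """Option 1: write into a private staging dir, then rename into cache_dir."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    # The staging dir is inside cache_dir, so os.replace stays on one filesystem
    # (atomic), while the staged file is never a direct child of cache_dir and a
    # non-recursive `<stem>.*` glob cannot list it.
    private = Path(tempfile.mkdtemp(prefix=f"tmp.pid_{os.getpid()}_", dir=cache_dir))
    try:
        staged = private / filename
        staged.write_bytes(data)
        final_path = cache_dir / filename
        os.replace(staged, final_path)
        return final_path
    finally:
        # Remove the staging dir whether or not the write succeeded.
        shutil.rmtree(private, ignore_errors=True)
```

The trade-off matches the PR description: option 2 would make the reader in `compiler/compiler.py` depend on the temp-file naming used in `runtime/cache.py`, while option 1 keeps that detail inside the writer.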

@tdoublep
Member

tdoublep commented Jul 10, 2024

There was a PR merged into Triton yesterday that tries to address this issue: triton-lang/triton#4295.

This fix is not yet included in triton==3.0.0 which was released on PyPI yesterday.

@tdoublep
Member

tdoublep commented Jul 11, 2024

So I've been digging into this a bit more and here is a summary of my findings:

  • Triton recently released v3.0.0, but it does not seem to include the fix for this issue.
  • Nevertheless, multiple people have reported, before yesterday, that recent Triton nightlies resolve the issue (or similar issues).
  • I can't really understand why any Triton nightly before that PR would actually be safe w.r.t this issue (see the nice diagram from the PR to understand why). Everything is confounded by the fact this is a race condition and not deterministically reproducible.
  • Therefore, in my view, to be fully protected from this issue we will need to pull in whatever Triton version comes after 3.0.0, which requires waiting for torch, and all other dependencies, to upgrade to it.
  • I would therefore propose that we proceed to merge [Bugfix] Add custom Triton cache manager to resolve MoE MP issue  #6140 in the meantime.
  • Main remaining question is whether this issue can potentially also affect Ray deployments.

@jeejeelee
Collaborator Author

Firstly, I strongly agree with merging #6140.

Regarding the point that multiple people have reported (https://github.com/vllm-project/vllm/issues/6103#issuecomment-2209298536), before yesterday, that recent Triton nightlies resolve the issue (or similar issues):
This bug is intermittent. In my experience with MoE models, it is more likely to occur with TP=8, typically during the CUDA graph capture stage. IMHO, it's possible that the error simply hasn't been reproduced again rather than actually being fixed.

@solesensei

solesensei commented Jul 11, 2024

Just for context, I hit the same bug with the latest vLLM release (v0.5.1) serving mistralai/Mixtral-8x22B-Instruct-v0.1:

python -u -m vllm.entrypoints.openai.api_server --model mistralai/Mixtral-8x22B-Instruct-v0.1 --dtype auto --tensor-parallel-size 4 --gpu-memory-utilization 0.95 --swap-space 4 --download-dir /data/vllm-data --max-seq-len-to-capture 8192 --host 0.0.0.0 --port 8079 --max-model-len 16384
INFO 07-11 15:25:06 async_llm_engine.py:646] Received request cmpl-378867b20e274343bfa3228284c0fe9c: prompt: '<s>[INST] How to write a proper yaml to deploy a service in kubernetes?[/INST]', params: SamplingParams(n=1, best_of=1, presence_penalty=0.0, frequency_penalty=0.0, repetition_penalty=1.0, temperature=0.2, top_p=1.0, top_k=-1, min_p=0.0, seed=None, use_beam_search=False, length_penalty=1.0, early_stopping=False, stop=[], stop_token_ids=[], include_stop_str_in_output=False, ignore_eos=False, max_tokens=100, min_tokens=0, logprobs=None, prompt_logprobs=None, skip_special_tokens=True, spaces_between_special_tokens=True, truncate_prompt_tokens=None), prompt_token_ids: [1, 3, 2370, 1066, 4092, 1032, 5747, 1105, 13916, 1066, 16026, 1032, 3140, 1065, 1214, 21331, 29572, 4], lora_request: None.
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop: [Errno 2] No such file or directory: '/root/.triton/cache/2fc210b572cc1ab473ded8f345844946/fused_moe_kernel.cubin.tmp.pid_145_821872', Traceback (most recent call last):
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 64, in start_worker_execution_loop
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     output = self.execute_model(execute_model_req=None)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 271, in execute_model
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1243, in execute_model
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     hidden_or_intermediate_states = model_executable(
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 348, in forward
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     hidden_states = self.model(input_ids, positions, kv_caches,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 276, in forward
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     hidden_states, residual = layer(positions, hidden_states,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 232, in forward
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     hidden_states = self.block_sparse_moe(hidden_states)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/models/mixtral.py", line 95, in forward
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     final_hidden_states = self.experts(hidden_states, router_logits)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 186, in forward
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     final_hidden_states = self.quant_method.apply(
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 68, in apply
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return fused_moe(x,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 574, in fused_moe
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return fused_experts(hidden_states,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 506, in fused_experts
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     invoke_fused_moe_kernel(intermediate_cache2,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 246, in invoke_fused_moe_kernel
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     fused_moe_kernel[grid](
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     self.cache[device][key] = compile(
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/triton/compiler/compiler.py", line 202, in compile
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return CompiledKernel(so_path, metadata_group.get(metadata_filename))
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/triton/compiler/compiler.py", line 230, in __init__
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     self.asm = {
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/site-packages/triton/compiler/compiler.py", line 231, in <dictcomp>
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     file.suffix[1:]: file.read_bytes() if file.suffix[1:] == driver.binary_ext else file.read_text()
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/pathlib.py", line 1134, in read_text
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     with self.open(mode='r', encoding=encoding, errors=errors) as f:
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]   File "/root/.pyenv/versions/3.10.6/lib/python3.10/pathlib.py", line 1119, in open
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226]     return self._accessor.open(self, mode, buffering, encoding, errors,
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226] FileNotFoundError: [Errno 2] No such file or directory: '/root/.triton/cache/2fc210b572cc1ab473ded8f345844946/fused_moe_kernel.cubin.tmp.pid_145_821872'
(VllmWorkerProcess pid=144) ERROR 07-11 15:25:07 multiproc_worker_utils.py:226] 

Reproduced 6 out of 7 times on the first request to the server.

@solesensei

Any updates on this?

@tdoublep
Member

The fix in #6140 is ready from my point of view; I will try to get it approved and merged ASAP.

bertmaher pushed a commit to bertmaher/triton that referenced this issue Dec 10, 2024
…riton-lang#4295)

5 participants