Skip to content

Commit

Permalink
[Bugfix] Add custom Triton cache manager to resolve MoE MP issue (#6140)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Chih-Chieh-Yang <[email protected]>
  • Loading branch information
tdoublep and cyang49 authored Jul 15, 2024
1 parent a63a4c6 commit eaec4b9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
5 changes: 5 additions & 0 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
Expand Down Expand Up @@ -42,6 +43,10 @@ def _init_executor(self) -> None:
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"

# workaround for https://github.com/vllm-project/vllm/issues/6103
if world_size > 1:
maybe_set_triton_cache_manager()

assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")

Expand Down
6 changes: 6 additions & 0 deletions vllm/triton_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)

__all__ = [
"maybe_set_triton_cache_manager",
]
53 changes: 53 additions & 0 deletions vllm/triton_utils/custom_cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)

from vllm.logger import init_logger

logger = init_logger(__name__)


def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""

def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
"").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")

0 comments on commit eaec4b9

Please sign in to comment.