From 7fa572640c0b618ca8311a529d66e2185d8fd885 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 15 Jul 2024 19:12:47 +0200 Subject: [PATCH] [Bugfix] Add custom Triton cache manager to resolve MoE MP issue (#6140) Signed-off-by: Thomas Parnell Co-authored-by: Chih-Chieh-Yang Signed-off-by: Alvant --- vllm/executor/multiproc_gpu_executor.py | 5 +++ vllm/triton_utils/__init__.py | 6 +++ vllm/triton_utils/custom_cache_manager.py | 53 +++++++++++++++++++++++ 3 files changed, 64 insertions(+) create mode 100644 vllm/triton_utils/__init__.py create mode 100644 vllm/triton_utils/custom_cache_manager.py diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index dcde27973f8ef..a0e248b2e1992 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -9,6 +9,7 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.utils import (cuda_device_count_stateless, error_on_invalid_device_count_status, get_distributed_init_method, get_open_port, @@ -42,6 +43,10 @@ def _init_executor(self) -> None: if "OMP_NUM_THREADS" not in os.environ: os.environ["OMP_NUM_THREADS"] = "1" + # workaround for https://github.com/vllm-project/vllm/issues/6103 + if world_size > 1: + maybe_set_triton_cache_manager() + assert world_size <= cuda_device_count_stateless(), ( "please set tensor_parallel_size to less than max local gpu count") diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py new file mode 100644 index 0000000000000..09843e5d1f30b --- /dev/null +++ b/vllm/triton_utils/__init__.py @@ -0,0 +1,6 @@ +from vllm.triton_utils.custom_cache_manager import ( + maybe_set_triton_cache_manager) + +__all__ = [ + "maybe_set_triton_cache_manager", +] diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py new 
# New file added by this patch (mode 100644):
#   vllm/triton_utils/custom_cache_manager.py
import os

from triton.runtime.cache import (FileCacheManager, default_cache_dir,
                                  default_dump_dir, default_override_dir)

from vllm.logger import init_logger

logger = init_logger(__name__)


def maybe_set_triton_cache_manager() -> None:
    """Select vLLM's Triton cache manager through the environment.

    If ``TRITON_CACHE_MANAGER`` is already present in the environment it
    is left untouched; otherwise it is set to the import path of
    ``CustomCacheManager`` and the choice is logged.
    """
    if os.environ.get("TRITON_CACHE_MANAGER", None) is not None:
        # A cache manager has already been chosen; respect it.
        return

    manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
    logger.info("Setting Triton cache manager to: %s", manager)
    os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
    """File cache manager that gives every process its own cache directory.

    This re-implements the constructor of Triton's cache manager so that
    the regular cache directory is suffixed with the current process id.
    Separate directories avoid collisions when running with tp>1 and
    multi-processing as the distributed backend.

    Note: the underlying issue was fixed by triton-lang/triton/pull/4295,
    but that fix is not yet included in triton==v3.0.0; it should be
    included in the subsequent version.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve the cache directory (and lock path) for ``key``.

        Args:
            key: Name of the sub-directory holding this entry.
            override: Use the override directory; nothing is created on
                disk and ``lock_path`` stays ``None``. Ignored when
                ``dump`` is truthy.
            dump: Use the dump directory; takes precedence over
                ``override``.

        Raises:
            RuntimeError: If neither ``TRITON_CACHE_DIR`` nor
                ``default_cache_dir()`` yields a usable directory.
        """
        self.key = key
        self.lock_path = None

        if dump:
            # Dump mode: <dump dir>/<key>, created eagerly, with a lock.
            self.cache_dir = os.path.join(default_dump_dir(), self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
            return

        if override:
            # Override mode: only the path is computed.
            self.cache_dir = os.path.join(default_override_dir(), self.key)
            return

        # Regular mode: an explicit TRITON_CACHE_DIR wins; a blank value
        # falls back to Triton's default location.
        configured_dir = os.getenv("TRITON_CACHE_DIR", "").strip()
        self.cache_dir = configured_dir or default_cache_dir()
        if not self.cache_dir:
            raise RuntimeError("Could not create or locate cache dir")

        # The pid suffix is what keeps concurrent workers apart.
        per_process_root = f"{self.cache_dir}_{os.getpid()}"
        self.cache_dir = os.path.join(per_process_root, self.key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        # create cache directory if it doesn't exist
        os.makedirs(self.cache_dir, exist_ok=True)