-
-
Notifications
You must be signed in to change notification settings - Fork 5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bugfix] Add custom Triton cache manager to resolve MoE MP issue (#6140)
Signed-off-by: Thomas Parnell <[email protected]> Co-authored-by: Chih-Chieh-Yang <[email protected]>
- Loading branch information
Showing 3 changed files with 64 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from vllm.triton_utils.custom_cache_manager import ( | ||
maybe_set_triton_cache_manager) | ||
|
||
__all__ = [ | ||
"maybe_set_triton_cache_manager", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
|
||
from triton.runtime.cache import (FileCacheManager, default_cache_dir, | ||
default_dump_dir, default_override_dir) | ||
|
||
from vllm.logger import init_logger | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
def maybe_set_triton_cache_manager() -> None:
    """Select vLLM's custom Triton cache manager unless one is already set.

    If the ``TRITON_CACHE_MANAGER`` environment variable is present it is
    left untouched. Otherwise it is set to the ``module:Class`` path of
    ``CustomCacheManager`` and the choice is logged.
    """
    if "TRITON_CACHE_MANAGER" in os.environ:
        # A cache manager has already been chosen; do not override it.
        return

    manager_path = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
    logger.info("Setting Triton cache manager to: %s", manager_path)
    os.environ["TRITON_CACHE_MANAGER"] = manager_path
|
||
|
||
class CustomCacheManager(FileCacheManager):
    """File cache manager that gives each process its own cache directory.

    This mirrors the constructor of Triton's ``FileCacheManager`` but
    appends the current process id to the default cache directory, so that
    workers do not collide when running with tp>1 and multi-processing as
    the distributed backend.

    Note: the underlying issue was fixed upstream in
    triton-lang/triton/pull/4295, which is not yet part of triton==v3.0.0
    but should be included in the subsequent version.
    """

    def __init__(self, key, override=False, dump=False):
        """Resolve the cache directory (and lock path) for ``key``.

        ``dump`` takes precedence over ``override``; when neither is set the
        regular cache directory is used, suffixed with the process id.

        Raises:
            RuntimeError: if no regular cache directory can be determined.
        """
        self.key = key
        self.lock_path = None

        if override and not dump:
            # Override mode: no lock file is set up and the directory is
            # not created here.
            self.cache_dir = os.path.join(default_override_dir(), key)
            return

        if dump:
            self.cache_dir = default_dump_dir()
        else:
            # Prefer a non-blank TRITON_CACHE_DIR, else Triton's default.
            self.cache_dir = (os.getenv("TRITON_CACHE_DIR", "").strip()
                              or default_cache_dir())
            if not self.cache_dir:
                raise RuntimeError("Could not create or locate cache dir")
            # Per-process suffix keeps concurrent workers apart.
            self.cache_dir = f"{self.cache_dir}_{os.getpid()}"

        # Shared tail for dump and regular modes: key subdirectory, lock
        # file path, and directory creation.
        self.cache_dir = os.path.join(self.cache_dir, key)
        self.lock_path = os.path.join(self.cache_dir, "lock")
        os.makedirs(self.cache_dir, exist_ok=True)