Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Add custom Triton cache manager to resolve MoE MP issue #6140

Merged
merged 10 commits into from
Jul 15, 2024
Merged
5 changes: 5 additions & 0 deletions vllm/executor/multiproc_gpu_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ResultHandler, WorkerMonitor)
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.triton_utils import maybe_set_triton_cache_manager
from vllm.utils import (cuda_device_count_stateless,
error_on_invalid_device_count_status,
get_distributed_init_method, get_open_port,
Expand Down Expand Up @@ -42,6 +43,10 @@ def _init_executor(self) -> None:
if "OMP_NUM_THREADS" not in os.environ:
os.environ["OMP_NUM_THREADS"] = "1"

# workaround for https://github.com/vllm-project/vllm/issues/6103
if world_size > 1:
maybe_set_triton_cache_manager()

assert world_size <= cuda_device_count_stateless(), (
"please set tensor_parallel_size to less than max local gpu count")

Expand Down
6 changes: 6 additions & 0 deletions vllm/triton_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from vllm.triton_utils.custom_cache_manager import (
maybe_set_triton_cache_manager)

__all__ = [
"maybe_set_triton_cache_manager",
]
53 changes: 53 additions & 0 deletions vllm/triton_utils/custom_cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import os

from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)

from vllm.logger import init_logger

logger = init_logger(__name__)


def maybe_set_triton_cache_manager() -> None:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
if cache_manger is None:
manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger.info("Setting Triton cache manager to: %s", manager)
os.environ["TRITON_CACHE_MANAGER"] = manager


class CustomCacheManager(FileCacheManager):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document why do we need this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added some docstrings

"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Comment on lines +22 to +25
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If triton 3.0.0 could solve this problem, it'd be better to note here that this custom cache manager can be removed when we upgrade triton.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix for the issue is not yet in v3.0.0, but I guess would be in whatever version comes after that (see my summary here). I will add a comment to that end.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""

def __init__(self, key, override=False, dump=False):
self.key = key
self.lock_path = None
if dump:
self.cache_dir = default_dump_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
elif override:
self.cache_dir = default_override_dir()
self.cache_dir = os.path.join(self.cache_dir, self.key)
else:
# create cache directory if it doesn't exist
self.cache_dir = os.getenv("TRITON_CACHE_DIR",
"").strip() or default_cache_dir()
if self.cache_dir:
self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
self.cache_dir = os.path.join(self.cache_dir, self.key)
self.lock_path = os.path.join(self.cache_dir, "lock")
os.makedirs(self.cache_dir, exist_ok=True)
else:
raise RuntimeError("Could not create or locate cache dir")
Loading