[Bugfix] Add custom Triton cache manager to resolve MoE MP issue #6140

Merged: 10 commits, Jul 15, 2024
10 changes: 10 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -272,6 +272,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
)


def maybe_set_triton_cache_manager(module: str) -> None:
    cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None)
    if cache_manager != module:
        os.environ["TRITON_CACHE_MANAGER"] = module
Collaborator:
If the user manually sets this env var, can we modify it? Additionally, I suggest adding a log message for clarification.

@tdoublep (Member, Author), Jul 4, 2024:
Have changed it so that we only set it if the user has not. Also added a log message.
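
A minimal sketch of the revised helper described in that reply — set the variable only when the user has not, and log the change. The use of Python's stdlib logging is an assumption (vLLM has its own logger utility); this is not the PR's verbatim final code:

import logging
import os

logger = logging.getLogger(__name__)


def maybe_set_triton_cache_manager(module: str) -> None:
    """Point Triton at a custom cache manager, unless the user already set one."""
    # Only override when TRITON_CACHE_MANAGER is unset, per the review feedback.
    if os.environ.get("TRITON_CACHE_MANAGER") is None:
        logger.info("Setting Triton cache manager to: %s", module)
        os.environ["TRITON_CACHE_MANAGER"] = module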



def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -428,6 +434,10 @@ def fused_experts(hidden_states: torch.Tensor,
    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
    M = min(num_tokens, CHUNK_SIZE)

    # workaround for https://github.com/vllm-project/vllm/issues/6103
    maybe_set_triton_cache_manager(
        "vllm.triton_utils.custom_cache_manager:CustomCacheManager")

    if override_config:
        config = override_config
    else:
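For context, Triton resolves the TRITON_CACHE_MANAGER value as a "module.path:ClassName" string and imports the named class the first time a cache manager is requested. A rough sketch of that resolution (a simplification, not Triton's exact code):

import importlib
import os


def _resolve_cache_manager_cls(default_cls):
    # A "module:ClassName" spec, e.g. the value passed above:
    # "vllm.triton_utils.custom_cache_manager:CustomCacheManager"
    spec = os.environ.get("TRITON_CACHE_MANAGER")
    if spec is None:
        return default_cls
    module_path, _, class_name = spec.partition(":")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)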
Empty file added vllm/triton_utils/__init__.py
30 changes: 30 additions & 0 deletions vllm/triton_utils/custom_cache_manager.py
@@ -0,0 +1,30 @@
import os

from triton.runtime.cache import (FileCacheManager, default_cache_dir,
                                  default_dump_dir, default_override_dir)


class CustomCacheManager(FileCacheManager):
Collaborator:
Document why we need this?

@tdoublep (Member, Author):
Added some docstrings.
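
A plausible shape for such a docstring, based on the PR's stated purpose (wording assumed, not quoted from the PR):

class CustomCacheManager(FileCacheManager):
    """Triton cache manager that keys the cache directory by process ID.

    Triton's default FileCacheManager shares a single cache directory
    across processes, which can race when several workers JIT-compile
    the same kernel, e.g. fused MoE under multiprocessing tensor
    parallelism (see issue #6103).
    """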


    def __init__(self, key, override=False, dump=False):
        self.key = key
        self.lock_path = None
        if dump:
            self.cache_dir = default_dump_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
            self.lock_path = os.path.join(self.cache_dir, "lock")
            os.makedirs(self.cache_dir, exist_ok=True)
        elif override:
            self.cache_dir = default_override_dir()
            self.cache_dir = os.path.join(self.cache_dir, self.key)
        else:
            # create cache directory if it doesn't exist
            self.cache_dir = os.getenv("TRITON_CACHE_DIR",
                                       "").strip() or default_cache_dir()
            if self.cache_dir:
                # append the process ID so that each process gets its own
                # cache directory, avoiding races between workers
                self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                self.cache_dir = os.path.join(self.cache_dir, self.key)
                self.lock_path = os.path.join(self.cache_dir, "lock")
                os.makedirs(self.cache_dir, exist_ok=True)
            else:
                raise RuntimeError("Could not create or locate cache dir")
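
Appending os.getpid() to the cache root is the crux of the fix: each process compiles into its own directory, so concurrent multi-process MoE workers can no longer clobber each other's compiled-kernel cache files. As a rough usage illustration (the key and printed path below are made up):

from vllm.triton_utils.custom_cache_manager import CustomCacheManager

mgr = CustomCacheManager("some_kernel_key")   # hypothetical kernel cache key
print(mgr.cache_dir)
# e.g. /home/user/.triton/cache_12345/some_kernel_key,
# where 12345 is this process's PID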