Skip to content

Commit

Permalink
add triton CustomCacheManager
Browse files Browse the repository at this point in the history
fixes RHOAIENG-8043

Co-authored-by: Chih-Chieh-Yang <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
  • Loading branch information
2 people authored and dtrifiro committed Jun 18, 2024
1 parent 3217143 commit 3aef43e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Dockerfile.ubi
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ ENV PATH=$VIRTUAL_ENV/bin/:$PATH
RUN microdnf install -y gcc \
&& microdnf clean all

# Custom cache manager (fix for https://issues.redhat.com/browse/RHOAIENG-8043)
COPY extras/custom_cache_manager.py /opt/vllm/lib/python3.11/site-packages/custom_cache_manager.py

# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
Expand All @@ -189,7 +192,8 @@ ENV HF_HUB_OFFLINE=1 \
PORT=8000 \
HOME=/home/vllm \
VLLM_USAGE_SOURCE=production-docker-image \
VLLM_WORKER_MULTIPROC_METHOD=fork
VLLM_WORKER_MULTIPROC_METHOD=fork \
TRITON_CACHE_MANAGER="custom_cache_manager:CustomCacheManager"

# setup non-root user for OpenShift
RUN umask 002 \
Expand Down
32 changes: 32 additions & 0 deletions extras/custom_cache_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import os

from triton.runtime.cache import (FileCacheManager, default_cache_dir,
default_dump_dir, default_override_dir)


class CustomCacheManager(FileCacheManager):
    """File cache manager whose default cache directory is per-process.

    In the default mode (neither ``dump`` nor ``override``) the process id is
    appended to the base cache directory, so each process works in its own
    directory tree. The dump and override modes use the directories returned
    by ``default_dump_dir()`` and ``default_override_dir()`` unchanged.
    """

    def __init__(self, key, override=False, dump=False):
        """Pick the cache directory for ``key`` and create it where needed.

        Args:
            key: Name of the sub-directory used for this cache entry.
            override: When true (and ``dump`` is false), point at the
                override directory. Nothing is created on disk and
                ``lock_path`` stays ``None``.
            dump: When true, use the dump directory; takes precedence over
                ``override``.

        Raises:
            RuntimeError: In the default mode, when neither the
                ``TRITON_CACHE_DIR`` environment variable nor
                ``default_cache_dir()`` yields a non-empty path.
        """
        self.key = key
        self.lock_path = None

        # Only the override mode skips lock-path setup and directory creation.
        wants_lock_and_dir = True

        if dump:
            target_dir = os.path.join(default_dump_dir(), key)
        elif override:
            target_dir = os.path.join(default_override_dir(), key)
            wants_lock_and_dir = False
        else:
            # A non-blank TRITON_CACHE_DIR wins over the library default.
            env_dir = os.getenv("TRITON_CACHE_DIR", "").strip()
            base_dir = env_dir or default_cache_dir()
            if not base_dir:
                raise RuntimeError("Could not create or locate cache dir")
            # Suffix the base directory with the pid so the resulting path
            # is specific to the current process.
            target_dir = os.path.join(f"{base_dir}_{os.getpid()}", key)

        self.cache_dir = target_dir
        if wants_lock_and_dir:
            self.lock_path = os.path.join(target_dir, "lock")
            # create cache directory if it doesn't exist
            os.makedirs(target_dir, exist_ok=True)

        print(f"Triton cache dir: {self.cache_dir=}")

0 comments on commit 3aef43e

Please sign in to comment.