From d75cdc09faffe30d18e1da0fec9d549ef06ffafb Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sun, 24 Mar 2024 09:40:09 -0700
Subject: [PATCH] Add flash attention to UBI docker build

And update UBI base image, grpcio-tools and accelerate dep versions.
---
 Dockerfile.ubi | 32 +++++++++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 3 deletions(-)

diff --git a/Dockerfile.ubi b/Dockerfile.ubi
index a67a779bf..dfa397cc5 100644
--- a/Dockerfile.ubi
+++ b/Dockerfile.ubi
@@ -1,5 +1,5 @@
 ## Global Args #################################################################
-ARG BASE_UBI_IMAGE_TAG=9.3-1552
+ARG BASE_UBI_IMAGE_TAG=9.3-1612
 ARG PYTHON_VERSION=3.11
 ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
 # ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
@@ -187,6 +187,25 @@ RUN curl -Lo vllm.whl https://github.com/vllm-project/vllm/releases/download/v${
     && rm vllm.whl
 # compiled extensions located at /workspace/vllm/*.so
 
+#################### FLASH_ATTENTION Build IMAGE ####################
+FROM dev as flash-attn-builder
+
+RUN microdnf install -y git \
+    && microdnf clean all
+
+# max jobs used for build
+ARG max_jobs=2
+ENV MAX_JOBS=${max_jobs}
+# flash attention version
+ARG flash_attn_version=v2.5.6
+ENV FLASH_ATTN_VERSION=${flash_attn_version}
+
+WORKDIR /usr/src/flash-attention-v2
+
+# Download the wheel or build it if a pre-compiled release doesn't exist
+RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
+    --no-build-isolation --no-deps --no-cache-dir
+
 ## Test ########################################################################
 FROM dev AS test
 
@@ -199,6 +218,9 @@ ADD . /vllm-workspace/
 # copy pytorch extensions separately to avoid having to rebuild
 # when python code changes
 COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
+# Install flash attention (from pre-built wheel)
+RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
+    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 # ignore build dependencies installation because we are using pre-complied extensions
 RUN rm pyproject.toml
 RUN --mount=type=cache,target=/root/.cache/pip \
@@ -254,9 +276,13 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     pip3 install \
     -r requirements.txt \
     # additional dependencies for the TGIS gRPC server
-    grpcio-tools==1.62.0 \
+    grpcio-tools==1.62.1 \
     # additional dependencies for openai api_server
-    accelerate==0.27.2
+    accelerate==0.28.0
+
+# Install flash attention (from pre-built wheel)
+RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
+    pip3 install /usr/src/flash-attention-v2/*.whl --no-cache-dir
 
 # vLLM will not be installed in site-packages
 COPY --from=vllm --link /workspace/ ./
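
Usage note (not part of the patch): a rough sketch of how the new build args introduced by the flash-attn-builder stage could be overridden when building the image. The image tag and the max_jobs value are illustrative assumptions; flash_attn_version=v2.5.6 is the default declared in the Dockerfile.

    # Build the UBI image, overriding the flash-attention build args.
    # Tag name and job count are arbitrary examples, not defined by the patch.
    docker build -f Dockerfile.ubi \
        --build-arg max_jobs=4 \
        --build-arg flash_attn_version=v2.5.6 \
        -t vllm-ubi:flash-attn .

Raising max_jobs speeds up the flash-attn wheel build when no pre-compiled release matches the environment, at the cost of higher memory use during compilation.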