diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
new file mode 100644
index 0000000000000..8350e2705141e
--- /dev/null
+++ b/.buildkite/generate_index.py
@@ -0,0 +1,24 @@
+import argparse
+import os
+
+template = """<!DOCTYPE html>
+<html>
+ <body>
+ <h1>Links for vLLM</h1>
+ <a href="../{wheel_html_escaped}">{wheel}</a>
+ </body>
+</html>
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--wheel", help="The wheel path.", required=True)
+args = parser.parse_args()
+
+filename = os.path.basename(args.wheel)
+
+with open("index.html", "w") as f:
+ print(f"Generated index.html for {args.wheel}")
+ # cloudfront requires escaping the '+' character
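+ # e.g. a (hypothetical) filename "vllm-1.0.0.dev+cu124-cp38-abi3-linux_x86_64.whl"
+ # is written into the href as "vllm-1.0.0.dev%2Bcu124-cp38-abi3-linux_x86_64.whl"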
+ f.write(
+ template.format(wheel=filename,
+ wheel_html_escaped=filename.replace("+", "%2B")))
diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
index 64ba1b32fb074..708e548727cf5 100644
--- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
+++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml
@@ -65,9 +65,9 @@ steps:
- VLLM_USAGE_SOURCE
- HF_TOKEN
- - block: "Run H100 Benchmark"
- key: block-h100
- depends_on: ~
+ #- block: "Run H100 Benchmark"
+ #key: block-h100
+ #depends_on: ~
- label: "H100"
# skip: "use this flag to conditionally skip the benchmark step, useful for PR testing"
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index 2de6fceb0c3fe..51618a2955fb1 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -55,3 +55,18 @@ steps:
password-env: DOCKERHUB_TOKEN
env:
DOCKER_BUILDKIT: "1"
+
+ - block: "Build CPU release image"
+ key: block-cpu-release-image-build
+ depends_on: ~
+
+ - label: "Build and publish CPU release image"
+ depends_on: block-cpu-release-image-build
+ agents:
+ queue: cpu_queue_postmerge
+ commands:
+ - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
+ - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION --progress plain -f Dockerfile.cpu ."
+ - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$RELEASE_VERSION"
+ env:
+ DOCKER_BUILDKIT: "1"
diff --git a/.buildkite/run-gh200-test.sh b/.buildkite/run-gh200-test.sh
index d06604f96f2b8..4fc6d089cc666 100644
--- a/.buildkite/run-gh200-test.sh
+++ b/.buildkite/run-gh200-test.sh
@@ -4,6 +4,9 @@
# It serves a sanity check for compilation and basic model usage.
set -ex
+# Skip installing a new torch during the build, since the Dockerfile already pins the required version for arm64
+python3 use_existing_torch.py
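+# (use_existing_torch.py, roughly, strips the torch pins from the requirements
+# files so that the torch version pinned for arm64 in the Dockerfile is reused.)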
+
# Try building the docker image
DOCKER_BUILDKIT=1 docker build . \
--target vllm-openai \
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 44f47fac1c1b3..b563c96343f92 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -224,8 +224,12 @@ steps:
mirror_hardwares: [amd]
source_file_dependencies:
- vllm/model_executor/layers
+ - vllm/model_executor/guided_decoding
- tests/test_logits_processor
- command: pytest -v -s test_logits_processor.py
+ - tests/model_executor/test_guided_processors
+ commands:
+ - pytest -v -s test_logits_processor.py
+ - pytest -v -s model_executor/test_guided_processors.py
- label: Speculative decoding tests # 30min
source_file_dependencies:
diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh
index 7345dd4e66b29..3c756659a715a 100644
--- a/.buildkite/upload-wheels.sh
+++ b/.buildkite/upload-wheels.sh
@@ -23,6 +23,8 @@ wheel="$new_wheel"
version=$(unzip -p "$wheel" '**/METADATA' | grep '^Version: ' | cut -d' ' -f2)
echo "Version: $version"
+normal_wheel="$wheel" # Save the original wheel filename
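+# From here on, "$normal_wheel" keeps the real version in its filename while
+# "$wheel" may be renamed below to the 1.0.0.dev alias; both copies get uploaded.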
+
# If the version contains "dev", rename it to v1.0.0.dev for consistency
if [[ $version == *dev* ]]; then
suffix="${version##*.}"
@@ -32,12 +34,38 @@ if [[ $version == *dev* ]]; then
new_version="1.0.0.dev"
fi
new_wheel="${wheel/$version/$new_version}"
- mv -- "$wheel" "$new_wheel"
+ # use cp to keep both files in the artifacts directory
+ cp -- "$wheel" "$new_wheel"
wheel="$new_wheel"
version="$new_version"
fi
# Upload the wheel to S3
+python3 .buildkite/generate_index.py --wheel "$normal_wheel"
+
+# generate index for this commit
aws s3 cp "$wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/"
+
+if [[ $normal_wheel == *"cu118"* ]]; then
+ # if $normal_wheel matches cu118, do not upload the index.html
+ echo "Skipping index files for cu118 wheels"
+else
+ # only upload index.html for cu12 wheels (default wheels)
+ aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html"
+ aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html"
+fi
+
+# generate index for nightly
aws s3 cp "$wheel" "s3://vllm-wheels/nightly/"
+aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/"
+
+if [[ $normal_wheel == *"cu118"* ]]; then
+ # if $normal_wheel matches cu118, do not upload the index.html
+ echo "Skipping index files for cu118 wheels"
+else
+ # only upload index.html for cu12 wheels (default wheels)
+ aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html"
+fi
+
aws s3 cp "$wheel" "s3://vllm-wheels/$version/"
\ No newline at end of file
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index c1051d10a4860..e40ceaaa8b037 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -39,67 +39,68 @@ jobs:
const script = require('.github/workflows/scripts/create_release.js')
await script(github, context, core)
- wheel:
- name: Build Wheel
- runs-on: ${{ matrix.os }}
- needs: release
-
- strategy:
- fail-fast: false
- matrix:
- os: ['ubuntu-20.04']
- python-version: ['3.9', '3.10', '3.11', '3.12']
- pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
- cuda-version: ['11.8', '12.1']
-
- steps:
- - name: Checkout
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-
- - name: Setup ccache
- uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
- with:
- create-symlink: true
- key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
-
- - name: Set up Linux Env
- if: ${{ runner.os == 'Linux' }}
- run: |
- bash -x .github/workflows/scripts/env.sh
-
- - name: Set up Python
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
- with:
- python-version: ${{ matrix.python-version }}
-
- - name: Install CUDA ${{ matrix.cuda-version }}
- run: |
- bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
-
- - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
- run: |
- bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
-
- - name: Build wheel
- shell: bash
- env:
- CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
- run: |
- bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
- wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
- asset_name=${wheel_name//"linux"/"manylinux1"}
- echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
- echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
-
- - name: Upload Release Asset
- uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- with:
- upload_url: ${{ needs.release.outputs.upload_url }}
- asset_path: ./dist/${{ env.wheel_name }}
- asset_name: ${{ env.asset_name }}
- asset_content_type: application/*
+ # NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow.
+ # wheel:
+ # name: Build Wheel
+ # runs-on: ${{ matrix.os }}
+ # needs: release
+
+ # strategy:
+ # fail-fast: false
+ # matrix:
+ # os: ['ubuntu-20.04']
+ # python-version: ['3.9', '3.10', '3.11', '3.12']
+ # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
+ # cuda-version: ['11.8', '12.1']
+
+ # steps:
+ # - name: Checkout
+ # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ # - name: Setup ccache
+ # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14
+ # with:
+ # create-symlink: true
+ # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }}
+
+ # - name: Set up Linux Env
+ # if: ${{ runner.os == 'Linux' }}
+ # run: |
+ # bash -x .github/workflows/scripts/env.sh
+
+ # - name: Set up Python
+ # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+ # with:
+ # python-version: ${{ matrix.python-version }}
+
+ # - name: Install CUDA ${{ matrix.cuda-version }}
+ # run: |
+ # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }}
+
+ # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }}
+ # run: |
+ # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }}
+
+ # - name: Build wheel
+ # shell: bash
+ # env:
+ # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size
+ # run: |
+ # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }}
+ # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename)
+ # asset_name=${wheel_name//"linux"/"manylinux1"}
+ # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV"
+ # echo "asset_name=${asset_name}" >> "$GITHUB_ENV"
+
+ # - name: Upload Release Asset
+ # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2
+ # env:
+ # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ # with:
+ # upload_url: ${{ needs.release.outputs.upload_url }}
+ # asset_path: ./dist/${{ env.wheel_name }}
+ # asset_name: ${{ env.asset_name }}
+ # asset_content_type: application/*
# (Danielkinz): This last step will publish the .whl to pypi. Warning: untested
# - name: Publish package
diff --git a/.gitignore b/.gitignore
index ceef6a5fba456..bb7e4d5b244a8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -81,6 +81,8 @@ instance/
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst
+docs/source/getting_started/examples/*.md
+!**/*.template.md
# PyBuilder
.pybuilder/
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bf19b3d227171..83c8033434f3b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -206,7 +206,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
- set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use")
+ set(CUTLASS_REVISION "v3.6.0" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -223,13 +223,13 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
FetchContent_Declare(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
- GIT_TAG v3.5.1
+ GIT_TAG 8aa95dbb888be6d81c6fbf7169718c5244b53227
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
# Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags.
# So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE
- GIT_SHALLOW TRUE
+ GIT_SHALLOW FALSE
)
endif()
FetchContent_MakeAvailable(cutlass)
@@ -241,7 +241,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
- "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
+ "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+ "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
+ "csrc/sparse/cutlass/sparse_compressor_entry.cu"
+ "csrc/cutlass_extensions/common.cpp")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
@@ -270,7 +273,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
" in CUDA target architectures")
endif()
- #
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
@@ -323,6 +325,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
+ #
+ # 2:4 Sparse Kernels
+
+ # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
+ # require CUDA 12.2 or later (and only work on Hopper, 9.0/9.0a for now).
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+ set(SRCS "csrc/sparse/cutlass/sparse_compressor_c3x.cu"
+ "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
+ set_gencode_flags_for_srcs(
+ SRCS "${SRCS}"
+ CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
+ list(APPEND VLLM_EXT_SRC "${SRCS}")
+ list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
+ message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
+ else()
+ if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS)
+ message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
+ "not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
+ "if you intend on running FP8 sparse quantized models on Hopper.")
+ else()
+ message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
+ "in CUDA target architectures")
+ endif()
+ endif()
+
#
# Machete kernels
@@ -404,7 +431,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
- INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
+ INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)
diff --git a/Dockerfile b/Dockerfile
index 123703848749c..153bff9cf565f 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,7 +2,7 @@
# to run the OpenAI compatible server.
# Please update any changes made here to
-# docs/source/dev/dockerfile/dockerfile.rst and
+# docs/source/dev/dockerfile/dockerfile.md and
# docs/source/assets/dev/dockerfile-stages-dependency.png
ARG CUDA_VERSION=12.4.1
@@ -45,17 +45,21 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
WORKDIR /workspace
# install build and runtime dependencies
-COPY requirements-common.txt requirements-common.txt
-COPY requirements-cuda.txt requirements-cuda.txt
-COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- python3 -m pip install -r requirements-cuda.txt
+# The arm64 (GH200) build follows the "use existing pytorch" practice: torch and
+# torchvision must be installed from the nightly builds first, so that pytorch
+# does not appear as a vLLM dependency in any of the steps after this one.
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
+ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
fi
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+ python3 -m pip install -r requirements-cuda.txt
+
# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
@@ -77,11 +81,6 @@ COPY requirements-build.txt requirements-build.txt
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install -r requirements-build.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
- if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
- fi
-
COPY . .
ARG GIT_REPO_CHECK=0
RUN --mount=type=bind,source=.git,target=.git \
@@ -157,8 +156,6 @@ WORKDIR /vllm-workspace
ENV DEBIAN_FRONTEND=noninteractive
ARG TARGETPLATFORM
-COPY requirements-cuda-arm64.txt requirements-cuda-arm64.txt
-
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
@@ -166,7 +163,7 @@ RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt-get update -y \
- && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \
+ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get update -y \
@@ -183,17 +180,20 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
# or future versions of triton.
RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
+# The arm64 (GH200) build follows the "use existing pytorch" practice: torch and
+# torchvision must be installed from the nightly builds first, so that pytorch
+# does not appear as a vLLM dependency in any of the steps after this one.
+RUN --mount=type=cache,target=/root/.cache/pip \
+ if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
+ python3 -m pip install --index-url https://download.pytorch.org/whl/nightly/cu124 "torch==2.6.0.dev20241210+cu124" "torchvision==0.22.0.dev20241215"; \
+ fi
+
# Install vllm wheel first, so that torch etc will be installed.
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
python3 -m pip install dist/*.whl --verbose
-RUN --mount=type=cache,target=/root/.cache/pip \
- if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- pip uninstall -y torch && \
- python3 -m pip install -r requirements-cuda-arm64.txt; \
- fi
-
RUN --mount=type=cache,target=/root/.cache/pip \
. /etc/environment && \
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
@@ -240,10 +240,11 @@ FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
- pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10'; \
+ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
else \
- pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10'; \
+ pip install accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
fi
+
ENV VLLM_USAGE_SOURCE production-docker-image
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
index ebe226cf6d148..f163edc27cba8 100644
--- a/Dockerfile.cpu
+++ b/Dockerfile.cpu
@@ -26,10 +26,10 @@ RUN pip install intel_extension_for_pytorch==2.5.0
WORKDIR /workspace
+COPY requirements-build.txt requirements-build.txt
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
RUN --mount=type=cache,target=/root/.cache/pip \
- --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
pip install --upgrade pip && \
pip install -r requirements-build.txt
@@ -37,9 +37,9 @@ FROM cpu-test-1 AS build
WORKDIR /workspace/vllm
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cpu.txt requirements-cpu.txt
RUN --mount=type=cache,target=/root/.cache/pip \
- --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
- --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
pip install -v -r requirements-cpu.txt
COPY . .
diff --git a/Dockerfile.rocm.ubi b/Dockerfile.rocm.ubi
index 8766b995bb555..cc4b81396ef86 100644
--- a/Dockerfile.rocm.ubi
+++ b/Dockerfile.rocm.ubi
@@ -49,8 +49,8 @@ RUN --mount=type=cache,target=/root/.cache/pip \
export version="$(awk -F. '{print $1"."$2}' <<< $ROCM_VERSION)" && \
uv pip install --pre \
--index-url "https://download.pytorch.org/whl/nightly/rocm${version}" \
- torch==2.6.0.dev20241107+rocm${version}\
- torchvision==0.20.0.dev20241107+rocm${version} && \
+ torch==2.6.0.dev20241122+rocm${version}\
+ torchvision==0.20.0.dev20241122+rocm${version} && \
# Install libdrm-amdgpu to avoid errors when retrieving device information (amdgpu.ids: No such file or directory)
microdnf install -y libdrm-amdgpu && \
microdnf clean all
diff --git a/README.md b/README.md
index 93b71ddaccc61..f83c9d759b359 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ vLLM is flexible and easy to use with:
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
- Transformer-like LLMs (e.g., Llama)
-- Mixture-of-Expert LLMs (e.g., Mixtral)
+- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
- Embedding Models (e.g. E5-Mistral)
- Multi-modal LLMs (e.g., LLaVA)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1e5967bd9bf8b..c1b10b3cf8f58 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -4,7 +4,8 @@
import json
import random
import time
-from typing import List, Optional
+from functools import cache
+from typing import Dict, List, Optional, Tuple
import torch
import uvloop
@@ -17,8 +18,11 @@
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.lora.utils import get_adapter_absolute_path
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import BeamSearchParams
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@@ -28,15 +32,17 @@ class SampleRequest:
Attributes:
prompt: The input text prompt for the model.
- multi_modal_data: Optional dictionary containing multi-modal data (e.g.
- images).
prompt_len: The length of the prompt in tokens.
expected_output_len: The expected length of the output in tokens.
+ multi_modal_data: Optional dictionary containing multi-modal data (e.g.
+ images).
+ lora_request: Optional LoRARequest specifying the LoRA to use.
"""
prompt: str
prompt_len: int
expected_output_len: int
multi_modal_data: Optional[MultiModalDataDict] = None
+ lora_request: Optional[LoRARequest] = None
def _get_prompt_for_image_model(question: str, *, model: str) -> str:
@@ -60,8 +66,30 @@ def _get_prompt_for_image_model(question: str, *, model: str) -> str:
raise ValueError(f"Unsupported model {model}")
+@cache
+def lora_path_on_disk(lora_path: str) -> str:
+ return get_adapter_absolute_path(lora_path)
+
+
+lora_tokenizer_cache: Dict[int, AnyTokenizer] = {}
+
+
+def get_random_lora_request(
+ args: argparse.Namespace
+) -> Tuple[LoRARequest, Optional[AnyTokenizer]]:
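+ """Pick a random LoRA id in [1, args.max_loras], build the matching
+ LoRARequest, and return it along with that adapter's tokenizer (cached per
+ id), or None when no LoRA tokenizer is available."""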
+ global lora_tokenizer_cache
+ lora_id = random.randint(1, args.max_loras)
+ lora_request = LoRARequest(lora_name=str(lora_id),
+ lora_int_id=lora_id,
+ lora_path=lora_path_on_disk(args.lora_path))
+ if lora_id not in lora_tokenizer_cache:
+ lora_tokenizer_cache[lora_id] = get_lora_tokenizer(lora_request)
+ return lora_request, lora_tokenizer_cache[lora_id]
+
+
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> List[SampleRequest]:
+
dataset_path: str = args.dataset
num_requests: int = args.num_prompts
fixed_output_len: Optional[int] = args.output_len
@@ -79,7 +107,9 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
# Filter out sequences that are too long or too short
filtered_dataset: List[SampleRequest] = []
- for data in dataset:
+ for data in tqdm(dataset,
+ total=len(dataset),
+ desc="sampling requests"):
if len(filtered_dataset) == num_requests:
break
@@ -102,9 +132,16 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
continue
prompt = _get_prompt_for_image_model(question=prompt, model=model)
+ request_tokenizer = tokenizer
+ lora_request: Optional[LoRARequest] = None
+ if args.enable_lora:
+ lora_request, lora_tokenizer = get_random_lora_request(args)
+ if lora_tokenizer:
+ request_tokenizer = lora_tokenizer
+
# Tokenize the prompts and completions.
- prompt_token_ids = tokenizer(prompt).input_ids
- completion_token_ids = tokenizer(completion).input_ids
+ prompt_token_ids = request_tokenizer(prompt).input_ids
+ completion_token_ids = request_tokenizer(completion).input_ids
prompt_len = len(prompt_token_ids)
output_len = len(completion_token_ids
) if fixed_output_len is None else fixed_output_len
@@ -118,7 +155,8 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
SampleRequest(prompt=prompt,
prompt_len=prompt_len,
expected_output_len=output_len,
- multi_modal_data=multi_modal_data))
+ multi_modal_data=multi_modal_data,
+ lora_request=lora_request))
return filtered_dataset
@@ -146,14 +184,21 @@ def run_vllm(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
+ lora_requests: Optional[List[LoRARequest]] = None
+ if engine_args.enable_lora:
+ lora_requests = [request.lora_request for request in requests]
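+ # llm.generate accepts a per-prompt list of LoRARequests (entries may be
+ # None), matched positionally with the prompts list.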
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
- llm.generate(prompts, sampling_params, use_tqdm=True)
+ llm.generate(prompts,
+ sampling_params,
+ lora_request=lora_requests,
+ use_tqdm=True)
end = time.perf_counter()
else:
+ assert lora_requests is None, "BeamSearch API does not support LoRA"
prompts = [request.prompt for request in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@@ -185,6 +230,7 @@ async def run_vllm_async(
# Add the requests to the engine.
prompts: List[TextPrompt] = []
sampling_params: List[SamplingParams] = []
+ lora_requests: List[Optional[LoRARequest]] = []
for request in requests:
prompts.append(
TextPrompt(prompt=request.prompt,
@@ -197,11 +243,16 @@ async def run_vllm_async(
ignore_eos=True,
max_tokens=request.expected_output_len,
))
+ lora_requests.append(request.lora_request)
generators = []
start = time.perf_counter()
- for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
- generator = llm.generate(prompt, sp, request_id=f"test{i}")
+ for i, (prompt, sp,
+ lr) in enumerate(zip(prompts, sampling_params, lora_requests)):
+ generator = llm.generate(prompt,
+ sp,
+ lora_request=lr,
+ request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
@@ -297,6 +348,14 @@ def main(args: argparse.Namespace):
vocab_size = tokenizer.vocab_size
requests = []
for _ in range(args.num_prompts):
+
+ request_tokenizer = tokenizer
+ lora_request: Optional[LoRARequest] = None
+ if args.enable_lora:
+ lora_request, lora_tokenizer = get_random_lora_request(args)
+ if lora_tokenizer:
+ request_tokenizer = lora_tokenizer
+
# Synthesize a prompt with the given input length.
candidate_ids = [
random.randint(0, vocab_size - 1)
@@ -305,8 +364,8 @@ def main(args: argparse.Namespace):
# As tokenizer may add additional tokens like BOS, we need to try
# different lengths to get the desired input length.
for _ in range(5): # Max attempts to correct
- candidate_prompt = tokenizer.decode(candidate_ids)
- tokenized_len = len(tokenizer.encode(candidate_prompt))
+ candidate_prompt = request_tokenizer.decode(candidate_ids)
+ tokenized_len = len(request_tokenizer.encode(candidate_prompt))
if tokenized_len == args.input_len:
break
@@ -323,7 +382,8 @@ def main(args: argparse.Namespace):
requests.append(
SampleRequest(prompt=candidate_prompt,
prompt_len=args.input_len,
- expected_output_len=args.output_len))
+ expected_output_len=args.output_len,
+ lora_request=lora_request))
else:
requests = sample_requests(tokenizer, args)
@@ -422,6 +482,14 @@ def main(args: argparse.Namespace):
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
+ # LoRA
+ parser.add_argument(
+ "--lora-path",
+ type=str,
+ default=None,
+ help="Path to the lora adapters to use. This can be an absolute path, "
+ "a relative path, or a Hugging Face model identifier.")
+
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
@@ -431,6 +499,8 @@ def main(args: argparse.Namespace):
assert args.output_len is not None
else:
assert args.input_len is None
+ if args.enable_lora:
+ assert args.lora_path is not None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
@@ -440,6 +510,9 @@ def main(args: argparse.Namespace):
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
+ if args.enable_lora:
+ raise ValueError("LoRA benchmarking is only supported for vLLM"
+ " backend")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
@@ -452,4 +525,7 @@ def main(args: argparse.Namespace):
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
+ if args.enable_lora:
+ raise ValueError("LoRA benchmarking is only supported for vLLM"
+ " backend")
main(args)
diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
new file mode 100644
index 0000000000000..3d1c5e392f9e2
--- /dev/null
+++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py
@@ -0,0 +1,384 @@
+import argparse
+import copy
+import itertools
+import pickle as pkl
+import time
+from typing import Callable, Iterable, List, Tuple
+
+import torch
+import torch.utils.benchmark as TBenchmark
+from torch.utils.benchmark import Measurement as TMeasurement
+from utils import make_rand_sparse_tensors
+from weight_shapes import WEIGHT_SHAPES
+
+from vllm import _custom_ops as ops
+from vllm.utils import FlexibleArgumentParser
+
+DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
+DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
+DEFAULT_TP_SIZES = [1]
+
+
+# bench
+def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
+ **kwargs) -> TMeasurement:
+ min_run_time = 1
+
+ globals = {
+ "args": args,
+ "kwargs": kwargs,
+ "fn": fn,
+ }
+ return TBenchmark.Timer(
+ stmt="fn(*args, **kwargs)",
+ globals=globals,
+ label=label,
+ sub_label=sub_label,
+ description=description,
+ ).blocked_autorange(min_run_time=min_run_time)
+
+
+def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ assert dtype == torch.int8
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
+ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+
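+ # Sanity check: the 2:4 sparse kernel run on (A, compressed B, metadata E)
+ # should reproduce the dense cutlass_scaled_mm result, since B was pruned
+ # to the 2:4 pattern before compression.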
+ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+ torch.bfloat16)
+ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+ if not torch.allclose(out, out_ref):
+ print("Incorrect results")
+ print(out)
+ print(out_ref)
+ else:
+ print("Correct results")
+
+ timers = []
+ # pytorch impl - bfloat16
+ timers.append(
+ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm, a.to(dtype=torch.bfloat16),
+ b.to(dtype=torch.bfloat16)))
+
+ # pytorch impl - float16
+ timers.append(
+ bench_fn(label, sub_label,
+ "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
+ a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
+
+ # cutlass impl
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+ torch.bfloat16))
+
+ # cutlass with bias
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
+ bias))
+
+ # cutlass sparse impl
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))
+
+ # cutlass sparse with bias
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16, bias))
+
+ return timers
+
+
+def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ assert dtype == torch.float8_e4m3fn
+ b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n,
+ k)
+ scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
+ bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
+
+ out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b,
+ torch.bfloat16)
+ out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
+
+ if not torch.allclose(out, out_ref):
+ print("Incorrect results")
+ print(out)
+ print(out_ref)
+ else:
+ print("Correct results")
+
+ timers = []
+
+ # pytorch impl w. bf16
+ timers.append(
+ bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
+ torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
+ b.to(dtype=torch.bfloat16, device="cuda")))
+
+ # pytorch impl: bf16 output, without fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16))
+
+ # pytorch impl: bf16 output, with fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.bfloat16,
+ use_fast_accum=True))
+
+ # pytorch impl: fp16 output, without fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16))
+
+ # pytorch impl: fp16 output, with fp8 fast accum
+ timers.append(
+ bench_fn(label,
+ sub_label,
+ "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
+ torch._scaled_mm,
+ a,
+ b,
+ scale_a=scale_a,
+ scale_b=scale_b,
+ out_dtype=torch.float16,
+ use_fast_accum=True))
+
+ # cutlass impl: bf16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
+ ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
+ torch.bfloat16))
+
+ # cutlass sparse impl: bf16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16))
+
+ # cutlass sparse impl: fp16 output
+ timers.append(
+ bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.float16))
+
+ # cutlass sparse impl: bf16 output, with bias
+ timers.append(
+ bench_fn(label, sub_label,
+ "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.bfloat16, bias))
+
+ # cutlass sparse impl: fp16 output, with bias
+ timers.append(
+ bench_fn(label, sub_label,
+ "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
+ ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a,
+ scale_b, torch.float16, bias.to(dtype=torch.float16)))
+
+ return timers
+
+
+def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
+ sub_label: str) -> Iterable[TMeasurement]:
+ if dtype == torch.int8:
+ return bench_int8(dtype, m, k, n, label, sub_label)
+ if dtype == torch.float8_e4m3fn:
+ return bench_fp8(dtype, m, k, n, label, sub_label)
+ raise ValueError("unsupported type")
+
+
+# runner
+def print_timers(timers: Iterable[TMeasurement]):
+ compare = TBenchmark.Compare(timers)
+ compare.print()
+
+
+def run(dtype: torch.dtype,
+ MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
+ results = []
+ for m, k, n in MKNs:
+ timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
+ f"MKN=({m}x{k}x{n})")
+ print_timers(timers)
+ results.extend(timers)
+
+ return results
+
+
+# output makers
+def make_output(data: Iterable[TMeasurement],
+ MKNs: Iterable[Tuple[int, int, int]],
+ base_description: str,
+ timestamp=None):
+ print(f"== All Results {base_description} ====")
+ print_timers(data)
+
+ # pickle all the results
+ timestamp = int(time.time()) if timestamp is None else timestamp
+ with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
+ pkl.dump(data, f)
+
+
+# argparse runners
+
+
+def run_square_bench(args):
+ dim_sizes = list(
+ range(args.dim_start, args.dim_end + 1, args.dim_increment))
+ MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
+ data = run(args.dtype, MKNs)
+
+ make_output(data, MKNs, f"square_bench-{args.dtype}")
+
+
+def run_range_bench(args):
+ dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
+ n = len(dim_sizes)
+ Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
+ Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
+ Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
+ MKNs = list(zip(Ms, Ks, Ns))
+ data = run(args.dtype, MKNs)
+
+ make_output(data, MKNs, f"range_bench-{args.dtype}")
+
+
+def run_model_bench(args):
+ print("Benchmarking models:")
+ for i, model in enumerate(args.models):
+ print(f"[{i}] {model}")
+
+ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
+ KNs = []
+ for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
+ KN[tp_split_dim] = KN[tp_split_dim] // tp_size
+ KNs.append(KN)
+ return KNs
+
+ model_bench_data = []
+ models_tps = list(itertools.product(args.models, args.tp_sizes))
+ for model, tp_size in models_tps:
+ Ms = args.batch_sizes
+ KNs = model_shapes(model, tp_size)
+ MKNs = []
+ for m in Ms:
+ for k, n in KNs:
+ MKNs.append((m, k, n))
+
+ data = run(args.dtype, MKNs)
+ model_bench_data.append(data)
+
+ # Print all results
+ for data, model_tp in zip(model_bench_data, models_tps):
+ model, tp_size = model_tp
+ print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
+ print_timers(data)
+
+ timestamp = int(time.time())
+
+ all_data = []
+ for d in model_bench_data:
+ all_data.extend(d)
+ # pickle all data
+ with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
+ pkl.dump(all_data, f)
+
+
+if __name__ == '__main__':
+
+ def to_torch_dtype(dt):
+ if dt == "int8":
+ return torch.int8
+ if dt == "fp8":
+ return torch.float8_e4m3fn
+ raise ValueError("unsupported dtype")
+
+ parser = FlexibleArgumentParser(
+ description="""
+Benchmark Cutlass GEMM.
+
+ To run square GEMMs:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
+
+ To run constant N and K and sweep M:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
+
+ To run dimensions from a model:
+ python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
+
+ Output:
+ - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
+ """, # noqa: E501
+ formatter_class=argparse.RawTextHelpFormatter)
+
+ parser.add_argument("--dtype",
+ type=to_torch_dtype,
+ required=True,
+ help="Available options are ['int8', 'fp8']")
+ subparsers = parser.add_subparsers(dest="cmd")
+
+ square_parser = subparsers.add_parser("square_bench")
+ square_parser.add_argument("--dim-start", type=int, required=True)
+ square_parser.add_argument("--dim-end", type=int, required=True)
+ square_parser.add_argument("--dim-increment", type=int, required=True)
+ square_parser.set_defaults(func=run_square_bench)
+
+ range_parser = subparsers.add_parser("range_bench")
+ range_parser.add_argument("--dim-start", type=int, required=True)
+ range_parser.add_argument("--dim-end", type=int, required=True)
+ range_parser.add_argument("--dim-increment", type=int, required=True)
+ range_parser.add_argument("--m-constant", type=int, default=None)
+ range_parser.add_argument("--n-constant", type=int, default=None)
+ range_parser.add_argument("--k-constant", type=int, default=None)
+ range_parser.set_defaults(func=run_range_bench)
+
+ model_parser = subparsers.add_parser("model_bench")
+ model_parser.add_argument("--models",
+ nargs="+",
+ type=str,
+ default=DEFAULT_MODELS,
+ choices=WEIGHT_SHAPES.keys())
+ model_parser.add_argument("--tp-sizes",
+ nargs="+",
+ type=int,
+ default=DEFAULT_TP_SIZES)
+ model_parser.add_argument("--batch-sizes",
+ nargs="+",
+ type=int,
+ default=DEFAULT_BATCH_SIZES)
+ model_parser.set_defaults(func=run_model_bench)
+
+ args = parser.parse_args()
+ args.func(args)
diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py
new file mode 100644
index 0000000000000..ef06fcd6604dd
--- /dev/null
+++ b/benchmarks/cutlass_benchmarks/utils.py
@@ -0,0 +1,96 @@
+# Cutlass bench utils
+from typing import Iterable, Tuple
+
+import torch
+
+import vllm._custom_ops as ops
+
+
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
+ finfo = torch.finfo(torch.float8_e4m3fn)
+ return torch.round(tensor.clamp(
+ min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
+
+
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
+ return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
+
+
+def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+ return tensor.to(dtype=torch.bfloat16)
+
+
+def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+ return tensor.to(dtype=torch.float16)
+
+
+def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
+ k: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ a = torch.randn((m, k), device='cuda') * 5
+ b = torch.randn((n, k), device='cuda').t() * 5
+
+ if dtype == torch.int8:
+ return to_int8(a), to_int8(b)
+ if dtype == torch.float8_e4m3fn:
+ return to_fp8(a), to_fp8(b)
+
+ raise ValueError("unsupported dtype")
+
+
+def prune_to_2_4(tensor):
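+ # Enforce 2:4 structured sparsity: in every contiguous group of 4 values,
+ # keep the 2 entries with the largest magnitude and zero out the other 2,
+ # e.g. [0.1, -3.0, 0.2, 4.0] -> [0.0, -3.0, 0.0, 4.0].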
+ # Reshape tensor to [N, 4] where N is number of groups of 4
+ original_shape = tensor.shape
+ reshaped = tensor.reshape(-1, 4)
+
+ # Get indices of top 2 absolute values in each group of 4
+ _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
+
+ # Create binary mask
+ mask = torch.zeros_like(reshaped)
+ mask.scatter_(dim=1,
+ index=indices,
+ src=torch.ones_like(indices, dtype=mask.dtype))
+
+ # Apply mask and reshape back
+ pruned = reshaped * mask
+
+ # Turn all -0.0 to 0.0
+ pruned[pruned == -0.0] = 0.0
+
+ return pruned.reshape(original_shape)
+
+
+def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int,
+ k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ a = torch.randn((m, k), device='cuda') * 5
+ b = torch.randn((n, k), device='cuda').t() * 5
+
+ b = prune_to_2_4(b.t()).t()
+
+ if dtype == torch.int8:
+ a, b = to_int8(a), to_int8(b)
+ elif dtype == torch.float8_e4m3fn:
+ a, b = to_fp8(a), to_fp8(b)
+ elif dtype == torch.float16:
+ a, b = to_fp16(a), to_fp16(b)
+ elif dtype == torch.bfloat16:
+ a, b = to_bf16(a), to_bf16(b)
+ else:
+ raise ValueError("unsupported dtype")
+
+ b_compressed, e = ops.cutlass_sparse_compress(b.t())
+
+ # Compressed B, Metadata, Original A, B
+ return b_compressed, e, a, b
+
+
+def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype,
+ m: int, n: int, k: int) -> \
+ Tuple[Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor], Iterable[torch.Tensor]]:
+ ABs = []
+ for _ in range(num_tensors):
+ b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
+ if b_comp is not None:
+ ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
+ BComps, Es, As, Bs = zip(*ABs)
+ return list(BComps), list(Es), list(As), list(Bs)
diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
index 63cf5d50cac75..d0353bc8cb42a 100644
--- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -8,6 +8,7 @@
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
+from utils import make_rand_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
@@ -17,31 +18,6 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
-# helpers
-
-
-def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
- finfo = torch.finfo(torch.float8_e4m3fn)
- return torch.round(tensor.clamp(
- min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
-
-
-def to_int8(tensor: torch.Tensor) -> torch.Tensor:
- return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
-
-
-def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
- k: int) -> Tuple[torch.Tensor, torch.Tensor]:
- a = torch.randn((m, k), device='cuda') * 5
- b = torch.randn((n, k), device='cuda').t() * 5
-
- if dtype == torch.int8:
- return to_int8(a), to_int8(b)
- if dtype == torch.float8_e4m3fn:
- return to_fp8(a), to_fp8(b)
-
- raise ValueError("unsupported dtype")
-
# bench
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
@@ -386,4 +362,4 @@ def to_torch_dtype(dt):
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
- args.func(args)
+ args.func(args)
\ No newline at end of file
diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py
index 25ec9d6028627..d58fb0bf86374 100644
--- a/benchmarks/cutlass_benchmarks/weight_shapes.py
+++ b/benchmarks/cutlass_benchmarks/weight_shapes.py
@@ -40,4 +40,4 @@
([8192, 57344], 1),
([28672, 8192], 0),
],
-}
+}
\ No newline at end of file
diff --git a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
index 2924ea4a49f54..94999630bae12 100644
--- a/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
+++ b/benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
@@ -10,7 +10,8 @@ set -ex
kill_gpu_processes() {
# kill all processes on GPU.
- pkill -f pt_main_thread
+ pgrep pt_main_thread | xargs -r kill -9
+ pgrep python3 | xargs -r kill -9
sleep 10
# remove vllm config file
@@ -54,7 +55,7 @@ benchmark() {
CUDA_VISIBLE_DEVICES=0 python3 \
-m vllm.entrypoints.openai.api_server \
- --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --model $model \
--port 8100 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
@@ -64,7 +65,7 @@ benchmark() {
CUDA_VISIBLE_DEVICES=1 python3 \
-m vllm.entrypoints.openai.api_server \
- --model meta-llama/Meta-Llama-3.1-8B-Instruct \
+ --model $model \
--port 8200 \
--max-model-len 10000 \
--gpu-memory-utilization 0.6 \
@@ -87,7 +88,7 @@ benchmark() {
--port 8100 \
--save-result \
--result-dir $results_folder \
- --result-filename disagg_prefill_2xtp4.json \
+ --result-filename disagg_prefill_tp1.json \
--request-rate "inf"
@@ -105,7 +106,7 @@ benchmark() {
--port 8200 \
--save-result \
--result-dir $results_folder \
- --result-filename disagg_prefill_2xtp4.json \
+ --result-filename disagg_prefill_tp1_overhead.json \
--request-rate "$qps"
kill_gpu_processes
@@ -118,7 +119,7 @@ main() {
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
- pip install quart httpx
+ pip install quart httpx datasets
cd "$(dirname "$0")"
diff --git a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
index d8d9e976dce76..eb5d891d0d4a5 100644
--- a/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
+++ b/benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
@@ -1,13 +1,12 @@
#!/bin/bash
-# Requirement: 8x H100 GPUs.
+# Requirement: 2x GPUs.
-# Model: neuralmagic/Meta-Llama-3-70B-Instruct-FP8-KV
-# Query: 2048 input tokens, 11 output tokens, QPS 4, 500 requests
-# Resource: 8x H100
+# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
+# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
+# Resource: 2x GPU
# Approaches:
-# 1. Chunked prefill: 1 vllm instance with tp=8
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
# Prefilling instance: max_output_token=1
@@ -114,7 +113,6 @@ benchmark() {
--request-rate "$qps"
sleep 2
-
}
@@ -123,8 +121,9 @@ main() {
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
(which jq) || (apt-get -y install jq)
(which socat) || (apt-get -y install socat)
+ (which lsof) || (apt-get -y install lsof)
- pip install quart httpx matplotlib aiohttp
+ pip install quart httpx matplotlib aiohttp datasets
cd "$(dirname "$0")"
diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp
new file mode 100644
index 0000000000000..ba9f40a230c8e
--- /dev/null
+++ b/csrc/core/math.hpp
@@ -0,0 +1,7 @@
+#include <climits>
+#include <cstdint>
+
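+// Rounds num up to the next power of two, e.g. next_pow_2(5) == 8 and
+// next_pow_2(1) == 1.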
+inline uint32_t next_pow_2(uint32_t const num) {
+ if (num <= 1) return num;
+ return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.cpp b/csrc/cutlass_extensions/common.cpp
new file mode 100644
index 0000000000000..3d2093ab94297
--- /dev/null
+++ b/csrc/cutlass_extensions/common.cpp
@@ -0,0 +1,11 @@
+#include "cutlass_extensions/common.hpp"
+
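+// Returns the compute capability of device 0 encoded as major * 10 + minor,
+// e.g. 90 for compute capability 9.0 (Hopper).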
+int32_t get_sm_version_num() {
+ int32_t major_capability, minor_capability;
+ cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+ 0);
+ cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+ 0);
+ int32_t version_num = major_capability * 10 + minor_capability;
+ return version_num;
+}
\ No newline at end of file
diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp
new file mode 100644
index 0000000000000..85e359aa57113
--- /dev/null
+++ b/csrc/cutlass_extensions/common.hpp
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include <climits>
+#include "cuda_runtime.h"
+#include <iostream>
+
+/**
+ * Helper function for checking CUTLASS errors
+ */
+#define CUTLASS_CHECK(status) \
+ { \
+ cutlass::Status error = status; \
+ TORCH_CHECK(error == cutlass::Status::kSuccess, \
+ cutlassGetStatusString(error)); \
+ }
+
+/**
+ * Panic wrapper for unwinding CUDA runtime errors
+ */
+#define CUDA_CHECK(status) \
+ { \
+ cudaError_t error = status; \
+ TORCH_CHECK(error == cudaSuccess, cudaGetErrorString(error)); \
+ }
+
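+/**
+ * Query the opt-in per-block shared memory limit of `device`; on Hopper this
+ * is well above the default static 48 KB limit (roughly 227 KB).
+ */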
+inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
+ int max_shared_mem_per_block_opt_in = 0;
+ cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
+ cudaDevAttrMaxSharedMemoryPerBlockOptin,
+ device);
+ return max_shared_mem_per_block_opt_in;
+}
+
+int32_t get_sm_version_num();
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
index c69e87999ae71..26f7423fd7455 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
/*
diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
index 95764ecddc79f..c723adf126422 100644
--- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
@@ -1,3 +1,5 @@
+#pragma once
+
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
/*
@@ -36,13 +38,13 @@ struct ScaledEpilogueBase {
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<1>, Int<0>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// Don't want to support nullptr by default
template <typename T, bool EnableNullPtr = false>
using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast<
- 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T,
+ 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<T>, EnableNullPtr>;
// This utility function constructs the arguments for the load descriptors
diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu
index fff7ce34c838a..24341d63fb1f8 100644
--- a/csrc/moe/moe_align_sum_kernels.cu
+++ b/csrc/moe/moe_align_sum_kernels.cu
@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
}
}
+// TODO(simon): this is temporarily adapted from
+// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
+// we did this to unblock Deepseek V3 but there should be a better
+// implementation to manage shared memory.
+template <typename scalar_t>
+__global__ void moe_align_block_size_global_mem_kernel(
+ scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+ int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
+ int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
+ const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+ const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+ for (int i = 0; i < num_experts; ++i) {
+ tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+ }
+
+ /**
+ * In the first step we compute token_cnts[thread_index + 1][expert_index],
+ * which counts how many tokens in the token shard of thread_index are
+ * assigned to expert expert_index.
+ */
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+ ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+ }
+
+ __syncthreads();
+
+ // For each expert we accumulate the token counts from the different threads.
+ if (threadIdx.x < num_experts) {
+ tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+ for (int i = 1; i <= blockDim.x; ++i) {
+ tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+ tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+ }
+ }
+
+ __syncthreads();
+
+ // We accumulate the token counts of all experts in thread 0.
+ if (threadIdx.x == 0) {
+ cumsum[0] = 0;
+ for (int i = 1; i <= num_experts; ++i) {
+ cumsum[i] = cumsum[i - 1] +
+ CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+ block_size) *
+ block_size;
+ }
+ *total_tokens_post_pad = cumsum[num_experts];
+ }
+
+ __syncthreads();
+
+ /**
+ * For each expert, each thread processes the tokens of the corresponding
+ * blocks and stores the corresponding expert_id for each block.
+ */
+ if (threadIdx.x < num_experts) {
+ for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+ i += block_size) {
+ expert_ids[i / block_size] = threadIdx.x;
+ }
+ }
+
+ /**
+ * Each thread processes a token shard, calculating the index of each token
+ * after sorting by expert number. Given the example topk_ids =
+ * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
+ * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+ * padding value(preset in python).
+ */
+ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+ int32_t expert_id = topk_ids[i];
+ /** The cumsum[expert_id] stores the starting index of the tokens that the
+ * expert with expert_id needs to process, and
+ * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
+ * processed by the expert with expert_id within the current thread's token
+ * shard.
+ */
+ int32_t rank_post_pad =
+ tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+ cumsum[expert_id];
+ sorted_token_ids[rank_post_pad] = i;
+ ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+ }
+}
+
template <typename scalar_t, int TOPK>
__global__ void moe_sum_kernel(
scalar_t* __restrict__ out, // [..., d]
@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- VLLM_DISPATCH_INTEGRAL_TYPES(
- topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
- // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
- // tensors
- const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
- const int32_t shared_mem =
- ((num_thread + 1) * num_experts + (num_experts + 1)) *
- sizeof(int32_t);
-
- // set dynamic shared mem
- auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
- AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
- (void*)kernel, shared_mem));
- kernel<<<1, num_thread, shared_mem, stream>>>(
- topk_ids.data_ptr(), sorted_token_ids.data_ptr(),
- experts_ids.data_ptr(),
- num_tokens_post_pad.data_ptr(), num_experts, block_size,
- topk_ids.numel());
- });
+
+ // If we have a very large number of experts, we can no longer use shared
+ // memory.
+ // TODO(simon): the right solution is to calculate the exact amount of
+ // shared memory needed and use that. The num_experts >= 256 threshold is
+ // just a temporary workaround to unblock DeepSeek-V3.
+ if (num_experts >= 256) {
+ VLLM_DISPATCH_INTEGRAL_TYPES(
+ topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+ // calc needed amount of global memory for `tokens_cnts` and `cumsum`
+ // tensors
+ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+ const int32_t mem_tokens_cnts =
+ ((num_experts + 1) * num_experts) * sizeof(int32_t);
+ const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+ // allocate global memory
+ int32_t* tokens_cnts;
+ int32_t* cumsum;
+ cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+ cudaMalloc(&cumsum, mem_cumsum);
+
+ auto kernel =
+ vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+ kernel<<<1, num_thread, 0, stream>>>(
+ topk_ids.data_ptr<scalar_t>(),
+ sorted_token_ids.data_ptr<int32_t>(),
+ experts_ids.data_ptr<int32_t>(),
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+ topk_ids.numel(), tokens_cnts, cumsum);
+ cudaFree(tokens_cnts);
+ cudaFree(cumsum);
+ });
+ } else {
+ VLLM_DISPATCH_INTEGRAL_TYPES(
+ topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+ // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+ // tensors
+ const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+ const int32_t shared_mem =
+ ((num_thread + 1) * num_experts + (num_experts + 1)) *
+ sizeof(int32_t);
+
+ // set dynamic shared mem
+ auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+ AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+ (void*)kernel, shared_mem));
+ kernel<<<1, num_thread, shared_mem, stream>>>(
+ topk_ids.data_ptr<scalar_t>(),
+ sorted_token_ids.data_ptr<int32_t>(),
+ experts_ids.data_ptr<int32_t>(),
+ num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+ topk_ids.numel());
+ });
+ }
}
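To see why the num_experts >= 256 threshold above exists, the shared-memory requirement of the original kernel can be evaluated directly (a standalone check, not part of the patch). The ~227 KiB figure in the comments is the per-block opt-in shared-memory limit on H100 (sm90); A100 allows roughly 163 KiB, so DeepSeek-V3's 256 routed experts overflow either way.

#include <cstdint>
#include <cstdio>

int main() {
  const int32_t WARP_SIZE = 32;  // assumption: NVIDIA warp size
  for (int32_t num_experts : {64, 128, 256}) {
    const int32_t num_thread =
        num_experts > WARP_SIZE ? num_experts : WARP_SIZE;
    const size_t shared_mem =
        ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
    std::printf("num_experts=%d -> %zu bytes (~%zu KiB)\n", num_experts,
                shared_mem, shared_mem / 1024);
  }
  // num_experts=64  -> ~16 KiB   (fits everywhere)
  // num_experts=128 -> ~65 KiB   (needs opt-in, still fits)
  // num_experts=256 -> ~258 KiB  (exceeds the ~227 KiB H100 limit, hence the
  //                               global-memory fallback above)
}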
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
diff --git a/csrc/ops.h b/csrc/ops.h
index 816b471d062d2..347c502845d8f 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -162,6 +162,17 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj,
 c10::optional<torch::Tensor> const& azp,
 c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
+
+void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
+ torch::Tensor const& b, torch::Tensor const& e,
+ torch::Tensor const& a_scales,
+ torch::Tensor const& b_scales,
+ c10::optional<torch::Tensor> const& bias);
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_compressed,
+ torch::Tensor& e, torch::Tensor const& a);
#endif
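The new sparse entry points mirror the dense scaled_mm API: a capability probe plus a matmul that takes the compressed operand and its metadata. A minimal illustrative sketch of the intended call pattern follows (the wrapper and its argument names are assumptions, not part of this header; it presumes ops.h and cutlass_extensions/common.hpp, which provides get_sm_version_num(), are included).

// Illustrative sketch only. `b_nzs` holds the 2:4-compressed weight and
// `b_meta` its sparsity metadata (see csrc/sparse/cutlass below).
void sparse_scaled_mm_or_throw(torch::Tensor& out, torch::Tensor const& a,
                               torch::Tensor const& b_nzs,
                               torch::Tensor const& b_meta,
                               torch::Tensor const& a_scales,
                               torch::Tensor const& b_scales) {
  TORCH_CHECK(cutlass_sparse_scaled_mm_supported(get_sm_version_num()),
              "cutlass_scaled_sparse_mm requires sm90 or newer");
  cutlass_scaled_sparse_mm(out, a, b_nzs, b_meta, a_scales, b_scales,
                           c10::nullopt /* bias */);
}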
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
diff --git a/csrc/quantization/cutlass_w8a8/common.hpp b/csrc/quantization/cutlass_w8a8/common.hpp
deleted file mode 100644
index bf04bb400790f..0000000000000
--- a/csrc/quantization/cutlass_w8a8/common.hpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#pragma once
-
-#include "cutlass/cutlass.h"
-#include
-
-/**
- * Helper function for checking CUTLASS errors
- */
-#define CUTLASS_CHECK(status) \
- { \
- TORCH_CHECK(status == cutlass::Status::kSuccess, \
- cutlassGetStatusString(status)) \
- }
-
-inline uint32_t next_pow_2(uint32_t const num) {
- if (num <= 1) return num;
- return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
-}
-
-inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
- int max_shared_mem_per_block_opt_in = 0;
- cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
- cudaDevAttrMaxSharedMemoryPerBlockOptin,
- device);
- return max_shared_mem_per_block_opt_in;
-}
-
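These helpers are not dropped: per the include changes elsewhere in this patch, CUTLASS_CHECK and get_sm_version_num() move under csrc/cutlass_extensions/common.hpp and next_pow_2 under csrc/core/math.hpp. Since the new dispatch headers lean on next_pow_2 for their M-bucketing, here is a quick standalone sanity check of its behavior (not part of the patch; __builtin_clz is GCC/Clang-specific, as in the original helper).

#include <cassert>
#include <climits>
#include <cstdint>

// Copied from the deleted helper above so the snippet is self-contained.
inline uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

int main() {
  assert(next_pow_2(1) == 1);
  assert(next_pow_2(33) == 64);
  assert(next_pow_2(64) == 64);    // exact powers of two map to themselves
  assert(next_pow_2(100) == 128);
  assert(next_pow_2(129) == 256);
  return 0;
}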
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
index d03242f44ab1d..f2fae4b66d651 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
@@ -21,15 +21,16 @@
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
-#include "common.hpp"
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
// clang-format on
using namespace cute;
/*
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm80EVT,
+ Epilogues defined in,
+ csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
+ must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index 33581a63d4c3d..123f4359c0d1a 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -1,384 +1,18 @@
-// clang-format will break include orders
-// clang-format off
 #include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
-#include
+ #include "scaled_mm_c3x_sm90_fp8_dispatch.cuh"
+ #include "scaled_mm_c3x_sm90_int8_dispatch.cuh"
-#include
-
-#include
-#include
-#include
-
-#include "cutlass/cutlass.h"
-
-#include "cute/tensor.hpp"
-#include "cute/atom/mma_atom.hpp"
-#include "cutlass/numeric_types.h"
-
-#include "cutlass/gemm/device/gemm_universal_adapter.h"
-#include "cutlass/gemm/kernel/gemm_universal.hpp"
-#include "cutlass/epilogue/collective/collective_builder.hpp"
-#include "cutlass/gemm/collective/collective_builder.hpp"
-
-#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
-#include "common.hpp"
-// clang-format on
-
-using namespace cute;
+ #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
using namespace vllm;
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper) or later.
-
- Epilogue functions can be defined to post-process the output before it is
- written to GPU memory.
- Epilogues must contain a public type named EVTCompute of type Sm90EVT,
- as well as a static prepare_args function that constructs an
- EVTCompute::Arguments struct.
*/
-namespace {
-
-// A wrapper for the GEMM kernel that is used to guard against compilation on
-// architectures that will never use the kernel. The purpose of this is to
-// reduce the size of the compiled binary.
-// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
-// into code that will be executed on the device where it is defined.
-template
-struct enable_sm90_or_later : Kernel {
- template
- CUTLASS_DEVICE void operator()(Args&&... args) {
- #if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
- Kernel::operator()(std::forward(args)...);
- #endif
- }
-};
-template typename Epilogue_,
- typename TileShape, typename ClusterShape, typename KernelSchedule,
- typename EpilogueSchedule>
-struct cutlass_3x_gemm {
- using ElementAB = ElementAB_;
- using ElementD = ElementD_;
- using ElementAcc =
- typename std::conditional, int32_t,
- float>::type;
-
- using EpilogueDescriptor =
- cutlass::epilogue::collective::detail::EpilogueDescriptor<
- TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
- ElementD, EpilogueSchedule>;
-
- using Epilogue = Epilogue_;
-
- using StrideD = Stride, Int<0>>;
- using ElementC = void;
- using StrideC = StrideD;
-
- using EVTCompute = typename Epilogue::EVTCompute;
-
- using CollectiveEpilogue =
- typename cutlass::epilogue::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
- ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
- ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
- EpilogueSchedule, EVTCompute>::CollectiveOp;
-
- static constexpr size_t CEStorageSize =
- sizeof(typename CollectiveEpilogue::SharedStorage);
- using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
- static_cast(CEStorageSize)>;
-
- // clang-format off
- using CollectiveMainloop =
- typename cutlass::gemm::collective::CollectiveBuilder<
- cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
- ElementAB, cutlass::layout::RowMajor, 16,
- ElementAB, cutlass::layout::ColumnMajor, 16,
- ElementAcc, TileShape, ClusterShape,
- Stages,
- KernelSchedule>::CollectiveOp;
- // clang-format on
-
- using KernelType = enable_sm90_or_later, CollectiveMainloop, CollectiveEpilogue,
- cutlass::gemm::PersistentScheduler>>;
-
- struct GemmKernel : public KernelType {};
-};
-
-template
-void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... epilogue_params) {
- using ElementAB = typename Gemm::ElementAB;
- using ElementD = typename Gemm::ElementD;
-
- int32_t m = a.size(0);
- int32_t n = b.size(1);
- int32_t k = a.size(1);
-
- int64_t lda = a.stride(0);
- int64_t ldb = b.stride(1);
- int64_t ldc = out.stride(0);
-
- using StrideA = Stride, int64_t>;
- using StrideB = Stride, int64_t>;
- using StrideC = typename Gemm::StrideC;
-
- StrideA a_stride{lda, Int<1>{}, 0};
- StrideB b_stride{ldb, Int<1>{}, 0};
- StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
-
- using GemmKernel = typename Gemm::GemmKernel;
- typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
-
- auto a_ptr = static_cast(a.data_ptr());
- auto b_ptr = static_cast(b.data_ptr());
- typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
- b_stride};
-
- auto c_ptr = static_cast(out.data_ptr());
- typename GemmKernel::EpilogueArguments epilogue_args{
- Gemm::Epilogue::prepare_args(
- std::forward(epilogue_params)...),
- c_ptr, c_stride, c_ptr, c_stride};
-
- typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
- prob_shape, mainloop_args, epilogue_args};
-
- // Launch the CUTLASS GEMM kernel.
- using GemmOp = cutlass::gemm::device::GemmUniversalAdapter;
- GemmOp gemm_op;
- CUTLASS_CHECK(gemm_op.can_implement(args));
-
- size_t workspace_size = gemm_op.get_workspace_size(args);
- auto const workspace_options =
- torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
- auto workspace = torch::empty(workspace_size, workspace_options);
-
- auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
-
- cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
- CUTLASS_CHECK(status);
-}
-
-template typename Epilogue>
-struct sm90_fp8_config_default {
- // M in (128, inf)
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_fp8_config_M128 {
- // M in (64, 128]
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_fp8_config_M64 {
- // M in [1, 64]
- static_assert(std::is_same());
- using KernelSchedule =
- cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _128>;
- using ClusterShape = Shape<_1, _8, _1>;
-
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_default {
- // For M > 128 and any N
- static_assert(std::is_same());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_128, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M128 {
- // For M in (64, 128] and any N
- static_assert(std::is_same());
- using KernelSchedule =
- typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _128>;
- using ClusterShape = Shape<_2, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M64 {
- // For M in (32, 64] and any N
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _1, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M32_NBig {
- // For M in [1, 32] and N >= 8192
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _128, _256>;
- using ClusterShape = Shape<_1, _4, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-template typename Epilogue>
-struct sm90_int8_config_M32_NSmall {
- // For M in [1, 32] and N < 8192
- static_assert(std::is_same());
- using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
- using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
- using TileShape = Shape<_64, _64, _256>;
- using ClusterShape = Shape<_1, _8, _1>;
- using Cutlass3xGemm =
- cutlass_3x_gemm;
-};
-
-} // namespace
-
-template typename Epilogue,
- typename... EpilogueArgs>
-void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same());
- TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
- TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
-
- using Cutlass3xGemmDefault =
- typename sm90_fp8_config_default::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_fp8_config_M64::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_fp8_config_M128::Cutlass3xGemm;
-
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast(64), next_pow_2(m)); // next power of 2
-
- if (mp2 <= 64) {
- // m in [1, 64]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- }
-}
-
-template typename Epilogue,
- typename... EpilogueArgs>
-void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
- torch::Tensor const& b,
- EpilogueArgs&&... args) {
- static_assert(std::is_same());
- TORCH_CHECK(a.dtype() == torch::kInt8);
- TORCH_CHECK(b.dtype() == torch::kInt8);
-
- using Cutlass3xGemmDefault =
- typename sm90_int8_config_default::Cutlass3xGemm;
- using Cutlass3xGemmM128 =
- typename sm90_int8_config_M128::Cutlass3xGemm;
- using Cutlass3xGemmM64 =
- typename sm90_int8_config_M64::Cutlass3xGemm;
- using Cutlass3xGemmM32NBig =
- typename sm90_int8_config_M32_NBig::Cutlass3xGemm;
- using Cutlass3xGemmM32NSmall =
- typename sm90_int8_config_M32_NSmall::Cutlass3xGemm;
-
- uint32_t const n = out.size(1);
- bool const is_small_n = n < 8192;
-
- uint32_t const m = a.size(0);
- uint32_t const mp2 =
- std::max(static_cast(32), next_pow_2(m)); // next power of 2
-
- if (mp2 <= 32) {
- // m in [1, 32]
- if (is_small_n) {
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else {
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- }
- } else if (mp2 <= 64) {
- // m in (32, 64]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else if (mp2 <= 128) {
- // m in (64, 128]
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- } else {
- // m in (128, inf)
- return cutlass_gemm_caller(
- out, a, b, std::forward(args)...);
- }
-}
-
template typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
new file mode 100644
index 0000000000000..d4bc2f0ade50d
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cuh
@@ -0,0 +1,160 @@
+#pragma once
+
+// clang-format will break include orders
+// clang-format off
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "core/math.hpp"
+#include "cutlass_extensions/common.hpp"
+// clang-format on
+
+/*
+ Epilogues defined in,
+ csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp,
+ must contain a public type named EVTCompute of type Sm90EVT, as well as a
+ static prepare_args function that constructs an EVTCompute::Arguments struct.
+*/
+
+using namespace cute;
+
+namespace vllm {
+
+// A wrapper for the GEMM kernel that is used to guard against compilation on
+// architectures that will never use the kernel. The purpose of this is to
+// reduce the size of the compiled binary.
+// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
+// into code that will be executed on the device where it is defined.
+template <typename Kernel>
+struct enable_sm90_or_later : Kernel {
+ template <typename... Args>
+ CUTLASS_DEVICE void operator()(Args&&... args) {
+#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
+ Kernel::operator()(std::forward<Args>(args)...);
+#endif
+ }
+};
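The guard above is purely a compile-time, device-side switch. A toy CUDA illustration of the same pattern (not vLLM code; simplified to a fixed argument list instead of the perfect forwarding used above):

// When this translation unit is compiled for an arch below sm90, the wrapper
// body is empty, so none of the wrapped kernel's sm90-only instructions are
// emitted into that arch's fatbin; host code can still refer to the type.
template <typename Kernel>
struct sm90_only_demo : Kernel {
  __device__ void operator()(int* out) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
    Kernel::operator()(out);
#endif
  }
};

struct WriteAnswer {
  __device__ void operator()(int* out) const { *out = 42; }
};

__global__ void demo_kernel(int* out) { sm90_only_demo<WriteAnswer>{}(out); }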
+
+template <typename ElementAB_, typename ElementD_,
+ template <typename, typename, typename> typename Epilogue_,
+ typename TileShape, typename ClusterShape, typename KernelSchedule,
+ typename EpilogueSchedule>
+struct cutlass_3x_gemm {
+ using ElementAB = ElementAB_;
+ using ElementD = ElementD_;
+ using ElementAcc =
+ typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
+ float>::type;
+
+ using EpilogueDescriptor =
+ cutlass::epilogue::collective::detail::EpilogueDescriptor<
+ TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD,
+ ElementD, EpilogueSchedule>;
+
+ using Epilogue = Epilogue_<ElementAcc, ElementD, EpilogueDescriptor>;
+
+ using StrideD = Stride<int64_t, Int<1>, Int<0>>;
+ using ElementC = void;
+ using StrideC = StrideD;
+
+ using EVTCompute = typename Epilogue::EVTCompute;
+
+ using CollectiveEpilogue =
+ typename cutlass::epilogue::collective::CollectiveBuilder<
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
+ ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAcc, float, ElementC, StrideC, 4, ElementD, StrideD, 4,
+ EpilogueSchedule, EVTCompute>::CollectiveOp;
+
+ static constexpr size_t CEStorageSize =
+ sizeof(typename CollectiveEpilogue::SharedStorage);
+ using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
+ static_cast(CEStorageSize)>;
+
+ // clang-format off
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+ ElementAB, cutlass::layout::RowMajor, 16,
+ ElementAB, cutlass::layout::ColumnMajor, 16,
+ ElementAcc, TileShape, ClusterShape,
+ Stages,
+ KernelSchedule>::CollectiveOp;
+ // clang-format on
+
+ using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
+ Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
+ cutlass::gemm::PersistentScheduler>>;
+
+ struct GemmKernel : public KernelType {};
+};
+
+template <typename Gemm, typename... EpilogueArgs>
+void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
+ torch::Tensor const& b,
+ EpilogueArgs&&... epilogue_params) {
+ using ElementAB = typename Gemm::ElementAB;
+ using ElementD = typename Gemm::ElementD;
+
+ int32_t m = a.size(0);
+ int32_t n = b.size(1);
+ int32_t k = a.size(1);
+
+ int64_t lda = a.stride(0);
+ int64_t ldb = b.stride(1);
+ int64_t ldc = out.stride(0);
+
+ using StrideA = Stride<int64_t, Int<1>, int64_t>;
+ using StrideB = Stride<int64_t, Int<1>, int64_t>;
+ using StrideC = typename Gemm::StrideC;
+
+ StrideA a_stride{lda, Int<1>{}, 0};
+ StrideB b_stride{ldb, Int<1>{}, 0};
+ StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
+
+ using GemmKernel = typename Gemm::GemmKernel;
+ typename GemmKernel::ProblemShape prob_shape{m, n, k, 1};
+
+ auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
+ auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
+ typename GemmKernel::MainloopArguments mainloop_args{a_ptr, a_stride, b_ptr,
+ b_stride};
+
+ auto c_ptr = static_cast<ElementD*>(out.data_ptr());
+ typename GemmKernel::EpilogueArguments epilogue_args{
+ Gemm::Epilogue::prepare_args(
+ std::forward<EpilogueArgs>(epilogue_params)...),
+ c_ptr, c_stride, c_ptr, c_stride};
+
+ typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
+ prob_shape, mainloop_args, epilogue_args};
+
+ // Launch the CUTLASS GEMM kernel.
+ using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+ GemmOp gemm_op;
+ CUTLASS_CHECK(gemm_op.can_implement(args));
+
+ size_t workspace_size = gemm_op.get_workspace_size(args);
+ auto const workspace_options =
+ torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
+ auto workspace = torch::empty(workspace_size, workspace_options);
+
+ auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+ cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
+ CUTLASS_CHECK(status);
+}
+
+} // namespace vllm
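With the launch boilerplate factored into this header, a shape-specific configuration from one of the new *_dispatch.cuh files composes with cutlass_gemm_caller as sketched below. Illustrative only: it assumes the sm90_fp8_config_default struct introduced later in this patch and the ScaledEpilogue from csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp; the exact epilogue name and namespace are taken on trust here.

// Sketch, not part of the patch: pick a config, take its Cutlass3xGemm type,
// and hand it to the shared caller together with the epilogue's arguments
// (here the a/b scales consumed by ScaledEpilogue::prepare_args).
using Fp8Bf16Gemm = typename vllm::sm90_fp8_config_default<
    cutlass::float_e4m3_t, cutlass::bfloat16_t,
    vllm::c3x::ScaledEpilogue>::Cutlass3xGemm;

void fp8_scaled_mm_bf16_out(torch::Tensor& out, torch::Tensor const& a,
                            torch::Tensor const& b,
                            torch::Tensor const& a_scales,
                            torch::Tensor const& b_scales) {
  vllm::cutlass_gemm_caller<Fp8Bf16Gemm>(out, a, b, a_scales, b_scales);
}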
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
new file mode 100644
index 0000000000000..f08419b3122b2
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_fp8_dispatch.cuh
@@ -0,0 +1,96 @@
+#pragma once
+
+#include "scaled_mm_c3x.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM90 (fp8) based on the Gemm
+ * shape.
+ */
+
+namespace vllm {
+
+template typename Epilogue>
+struct sm90_fp8_config_default {
+ // M in (128, inf)
+ static_assert(std::is_same());
+ using KernelSchedule =
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_128, _128, _128>;
+ using ClusterShape = Shape<_2, _1, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_fp8_config_M128 {
+ // M in (64, 128]
+ static_assert(std::is_same());
+ using KernelSchedule =
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _128, _128>;
+ using ClusterShape = Shape<_2, _1, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_fp8_config_M64 {
+ // M in [1, 64]
+ static_assert(std::is_same());
+ using KernelSchedule =
+ cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _64, _128>;
+ using ClusterShape = Shape<_1, _8, _1>;
+
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue,
+ typename... EpilogueArgs>
+inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ EpilogueArgs&&... args) {
+ static_assert(std::is_same());
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+ TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
+
+ using Cutlass3xGemmDefault =
+ typename sm90_fp8_config_default::Cutlass3xGemm;
+ using Cutlass3xGemmM64 =
+ typename sm90_fp8_config_M64::Cutlass3xGemm;
+ using Cutlass3xGemmM128 =
+ typename sm90_fp8_config_M128::Cutlass3xGemm;
+
+ uint32_t const m = a.size(0);
+ uint32_t const mp2 =
+ std::max(static_cast(64), next_pow_2(m)); // next power of 2
+
+ if (mp2 <= 64) {
+ // m in [1, 64]
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ } else if (mp2 <= 128) {
+ // m in (64, 128]
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ } else {
+ // m in (128, inf)
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ }
+}
+
+} // namespace vllm
\ No newline at end of file
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
new file mode 100644
index 0000000000000..34e5fd90ba26a
--- /dev/null
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90_int8_dispatch.cuh
@@ -0,0 +1,140 @@
+#pragma once
+
+#include "scaled_mm_c3x.cuh"
+
+/**
+ * This file defines Gemm kernel configurations for SM90 (int8) based on the
+ * Gemm shape.
+ */
+
+namespace vllm {
+
+template typename Epilogue>
+struct sm90_int8_config_default {
+ // For M > 128 and any N
+ static_assert(std::is_same());
+ using KernelSchedule =
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_128, _128, _128>;
+ using ClusterShape = Shape<_2, _1, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_int8_config_M128 {
+ // For M in (64, 128] and any N
+ static_assert(std::is_same());
+ using KernelSchedule =
+ typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _128, _128>;
+ using ClusterShape = Shape<_2, _1, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_int8_config_M64 {
+ // For M in (32, 64] and any N
+ static_assert(std::is_same());
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _64, _256>;
+ using ClusterShape = Shape<_1, _1, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_int8_config_M32_NBig {
+ // For M in [1, 32] and N >= 8192
+ static_assert(std::is_same());
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _128, _256>;
+ using ClusterShape = Shape<_1, _4, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue>
+struct sm90_int8_config_M32_NSmall {
+ // For M in [1, 32] and N < 8192
+ static_assert(std::is_same());
+ using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
+ using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
+ using TileShape = Shape<_64, _64, _256>;
+ using ClusterShape = Shape<_1, _8, _1>;
+ using Cutlass3xGemm =
+ cutlass_3x_gemm;
+};
+
+template typename Epilogue,
+ typename... EpilogueArgs>
+inline void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out,
+ torch::Tensor const& a,
+ torch::Tensor const& b,
+ EpilogueArgs&&... args) {
+ static_assert(std::is_same());
+ TORCH_CHECK(a.dtype() == torch::kInt8);
+ TORCH_CHECK(b.dtype() == torch::kInt8);
+
+ using Cutlass3xGemmDefault =
+ typename sm90_int8_config_default::Cutlass3xGemm;
+ using Cutlass3xGemmM128 =
+ typename sm90_int8_config_M128::Cutlass3xGemm;
+ using Cutlass3xGemmM64 =
+ typename sm90_int8_config_M64::Cutlass3xGemm;
+ using Cutlass3xGemmM32NBig =
+ typename sm90_int8_config_M32_NBig::Cutlass3xGemm;
+ using Cutlass3xGemmM32NSmall =
+ typename sm90_int8_config_M32_NSmall::Cutlass3xGemm;
+
+ uint32_t const n = out.size(1);
+ bool const is_small_n = n < 8192;
+
+ uint32_t const m = a.size(0);
+ uint32_t const mp2 =
+ std::max(static_cast(32), next_pow_2(m)); // next power of 2
+
+ if (mp2 <= 32) {
+ // m in [1, 32]
+ if (is_small_n) {
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ } else {
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ }
+ } else if (mp2 <= 64) {
+ // m in (32, 64]
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ } else if (mp2 <= 128) {
+ // m in (64, 128]
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ } else {
+ // m in (128, inf)
+ return cutlass_gemm_caller(
+ out, a, b, std::forward(args)...);
+ }
+}
+
+} // namespace vllm
\ No newline at end of file
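For reference, the selection logic above (and the simpler M-only version in the fp8 header) can be restated as a standalone function. std::bit_ceil (C++20) plays the role of next_pow_2 here so the snippet has no vLLM dependencies (not part of the patch).

#include <algorithm>
#include <bit>
#include <cstdint>
#include <cstdio>

static const char* sm90_int8_config_for(uint32_t m, uint32_t n) {
  uint32_t const mp2 = std::max(32u, std::bit_ceil(m));
  if (mp2 <= 32) return n < 8192 ? "M32_NSmall" : "M32_NBig";  // m in [1, 32]
  if (mp2 <= 64) return "M64";                                 // m in (32, 64]
  if (mp2 <= 128) return "M128";                               // m in (64, 128]
  return "default";                                            // m in (128, inf)
}

int main() {
  std::printf("%s\n", sm90_int8_config_for(8, 4096));    // M32_NSmall
  std::printf("%s\n", sm90_int8_config_for(8, 14336));   // M32_NBig
  std::printf("%s\n", sm90_int8_config_for(48, 4096));   // M64
  std::printf("%s\n", sm90_int8_config_for(100, 4096));  // M128
  std::printf("%s\n", sm90_int8_config_for(512, 4096));  // default
}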
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index 97a969cf5e3e0..4f7b6588ef3f7 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -3,6 +3,8 @@
#include
#include
+#include "cutlass_extensions/common.hpp"
+
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
@@ -79,16 +81,6 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
return false;
}
-int32_t get_sm_version_num() {
- int32_t major_capability, minor_capability;
- cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
- 0);
- cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
- 0);
- int32_t version_num = major_capability * 10 + minor_capability;
- return version_num;
-}
-
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
diff --git a/csrc/sparse/cutlass/sparse_compressor_c3x.cu b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
new file mode 100644
index 0000000000000..bd53695503241
--- /dev/null
+++ b/csrc/sparse/cutlass/sparse_compressor_c3x.cu
@@ -0,0 +1,165 @@
+// clang-format will break include orders
+// clang-format off
+#include
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12020
+#include "sparse_scaled_mm_c3x.cuh"
+
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/transform/device/transform_universal_adapter.hpp"
+#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+
+#include "cutlass/util/host_tensor.h"
+#include "cutlass/util/packed_stride.hpp"
+// clang-format on
+
+using namespace cute;
+using namespace vllm;
+
+/// Make A structured sparse by replacing elements with 0 and compress it
+template
+bool cutlass_sparse_compress(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+ torch::Tensor const& a) {
+ // Checks for conformality
+ TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
+ a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
+ TORCH_CHECK(a.dim() == 2)
+ // Check for strides and alignment
+ TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
+ TORCH_CHECK(a.stride(1) == 1)
+
+ int m = a.size(0);
+ int k = a.size(1);
+
+ // Sparse kernel setup; this kernel is not used for matmul,
+ // but just for setting up the compressor utility
+ // A matrix configuration
+ using ElementA = ElementA_;
+ using LayoutTagA = cutlass::layout::RowMajor;
+ constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value;
+ // B matrix configuration
+ using ElementB = ElementA;
+ using LayoutTagB = cutlass::layout::ColumnMajor;
+ constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value;
+ // C/D matrix configuration
+ using ElementC = float;
+ using LayoutTagC = cutlass::layout::ColumnMajor;
+ constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value;
+ // Core kernel configurations
+ using ElementAccumulator = ElementAcc_;
+ using TileShape = Shape<_128, _128, _128>;
+ using TileShapeRef = Shape<_128, _128, _64>;
+ using ClusterShape = Shape<_1, _2, _1>;
+ using KernelSchedule = typename std::conditional<
+ std::is_same_v,
+ cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum,
+ cutlass::gemm::KernelTmaWarpSpecialized>::type;
+
+ using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
+ using ProblemShape = Shape;
+
+ using CollectiveEpilogue =
+ typename cutlass::epilogue::collective::CollectiveBuilder<
+ cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
+ ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+ ElementAccumulator, ElementAccumulator, ElementC, LayoutTagC,
+ AlignmentC, ElementC, LayoutTagC, AlignmentC,
+ EpilogueSchedule>::CollectiveOp;
+
+ using CollectiveMainloop =
+ typename cutlass::gemm::collective::CollectiveBuilder<
+ cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, ElementA,
+ LayoutTagA, AlignmentA, ElementB, LayoutTagB, AlignmentB,
+ ElementAccumulator, TileShape, ClusterShape,
+ cutlass::gemm::collective::StageCountAutoCarveout(
+ sizeof(typename CollectiveEpilogue::SharedStorage))>,
+ KernelSchedule>::CollectiveOp;
+
+ using GemmKernel =
+ cutlass::gemm::kernel::GemmUniversal;
+
+ using Gemm = cutlass::gemm::device::GemmUniversalAdapter;
+
+ using StrideA = cutlass::gemm::TagToStrideA_t;
+ using StrideE = StrideA;
+
+ using StrideA = Stride, int64_t>;
+
+ // The n (=1) dimension does not matter for the compressor
+ typename GemmKernel::ProblemShape prob_shape{m, 1, k, 1};
+
+ using LayoutA = typename GemmKernel::CollectiveMainloop::LayoutA;
+ using LayoutE = typename GemmKernel::CollectiveMainloop::LayoutE;
+
+ using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
+ using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
+
+ // Offline compressor kernel
+ using CompressorUtility =
+ cutlass::transform::kernel::StructuredSparseCompressorUtility<
+ ProblemShape, ElementA, LayoutTagA, SparseConfig>;
+
+ using CompressorKernel =
+ cutlass::transform::kernel::StructuredSparseCompressor<
+ ProblemShape, ElementA, LayoutTagA, SparseConfig,
+ cutlass::arch::Sm90>;
+
+ using Compressor =
+ cutlass::transform::device::TransformUniversalAdapter;
+
+ auto [M, N, K, L] = prob_shape;
+
+ StrideA stride_A;
+ stride_A =
+ cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
+
+ CompressorUtility compressor_utility(prob_shape, stride_A);
+
+ int ME = compressor_utility.get_metadata_m_physical();
+ int KE = compressor_utility.get_metadata_k_physical();
+ int KC = compressor_utility.get_tensorA_k_physical();
+
+ auto a_ptr = static_cast(a.data_ptr());
+
+ auto a_nzs_ptr = static_cast(a_nzs.data_ptr());
+ auto a_meta_ptr = static_cast(
+ a_meta.data_ptr());
+
+ cutlass::KernelHardwareInfo hw_info;
+ hw_info.device_id = 0;
+ hw_info.sm_count =
+ cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
+ hw_info.device_id);
+ typename Compressor::Arguments arguments{
+ prob_shape, {a_ptr, stride_A, a_nzs_ptr, a_meta_ptr}, {hw_info}};
+
+ Compressor compressor_op;
+ size_t workspace_size = Compressor::get_workspace_size(arguments);
+ cutlass::device_memory::allocation workspace(workspace_size);
+
+ CUTLASS_CHECK(compressor_op.can_implement(arguments));
+ CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
+ CUTLASS_CHECK(compressor_op.run());
+ CUDA_CHECK(cudaDeviceSynchronize());
+
+ return true;
+}
+
+bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+ torch::Tensor const& a) {
+ if (a.dtype() == torch::kBFloat16) {
+ return cutlass_sparse_compress(a_nzs, a_meta,
+ a);
+ } else if (a.dtype() == torch::kFloat16) {
+ return cutlass_sparse_compress(a_nzs, a_meta, a);
+ } else if (a.dtype() == torch::kFloat8_e4m3fn) {
+ return cutlass_sparse_compress(a_nzs, a_meta,
+ a);
+ } else if (a.dtype() == torch::kInt8) {
+ return cutlass_sparse_compress(a_nzs, a_meta, a);
+ }
+ return false;
+}
+#endif
diff --git a/csrc/sparse/cutlass/sparse_compressor_entry.cu b/csrc/sparse/cutlass/sparse_compressor_entry.cu
new file mode 100644
index 0000000000000..3401761c1b703
--- /dev/null
+++ b/csrc/sparse/cutlass/sparse_compressor_entry.cu
@@ -0,0 +1,42 @@
+#include
+
+#include
+#include
+
+#include "cutlass_extensions/common.hpp"
+
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
+bool cutlass_sparse_compress_sm90(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+ torch::Tensor const& a);
+#endif
+
+bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
+ torch::Tensor const& a) {
+ // Checks for conformality
+ TORCH_CHECK(a.dim() == 2 && a_meta.dim() == 2 && a_nzs.dim() == 2);
+ TORCH_CHECK(a.size(0) == a_nzs.size(0) && a.size(0) == a_meta.size(0) &&
+ a_nzs.size(1) * 2 == a.size(1) &&
+ a_meta.size(1) * 2 * 4 == a.size(1));
+ // Each metadata byte covers 4 non-zeros (8 bits / 2 bits per non-zero),
+ // hence the 2 * 4 factor above.
+
+ // Check for strides and alignment
+ TORCH_CHECK(a.stride(1) == 1 && a_nzs.stride(1) == 1 &&
+ a_meta.stride(1) == 1); // Row-major
+ TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
+
+ at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
+ int32_t version_num = get_sm_version_num();
+
+ // Guard against compilation issues for sm90 kernels
+#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
+ if (version_num >= 90) {
+ return cutlass_sparse_compress_sm90(a_nzs, a_meta, a);
+ }
+#endif
+
+ TORCH_CHECK_NOT_IMPLEMENTED(
+ false,
+ "No compiled cutlass_scaled_sparse_mm for a compute capability less than "
+ "CUDA device capability: ",
+ version_num);
+}
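Put concretely, for a row-major A of shape [M, K] the checks above require a_nzs to be [M, K/2] (the kept values) and a_meta to be [M, K/8] (four 2-bit non-zero positions per byte; uint8, matching the dtype check in the sparse matmul path below). A minimal host-side sketch of allocating conforming tensors and invoking the entry point (not part of the patch; in practice A must already be 2:4 structured, e.g. pruned offline, and random data here only illustrates shapes, strides, and dtypes).

#include <torch/all.h>

// Declared in csrc/ops.h (added in this patch).
bool cutlass_sparse_compress_entry(torch::Tensor& a_nzs, torch::Tensor& a_meta,
                                   torch::Tensor const& a);

void compress_shapes_example() {
  int64_t const M = 128, K = 4096;
  torch::Tensor a =
      torch::randn({M, K}, torch::dtype(torch::kFloat16).device(torch::kCUDA))
          .to(torch::kFloat8_e4m3fn);  // contiguous, so a.stride(0) == K
  torch::Tensor a_nzs = torch::empty(
      {M, K / 2}, torch::dtype(torch::kFloat8_e4m3fn).device(torch::kCUDA));
  torch::Tensor a_meta = torch::empty(
      {M, K / 8}, torch::dtype(torch::kUInt8).device(torch::kCUDA));
  TORCH_CHECK(cutlass_sparse_compress_entry(a_nzs, a_meta, a));
}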
diff --git a/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
new file mode 100644
index 0000000000000..6223dc8cca704
--- /dev/null
+++ b/csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
@@ -0,0 +1,303 @@
+// clang-format will break include orders
+// clang-format off
+#include
+
+#if defined CUDA_VERSION && CUDA_VERSION >= 12020
+#include "sparse_scaled_mm_c3x.cuh"
+// clang-format on
+
+using namespace cute;
+using namespace vllm;
+
+template typename Epilogue,
+ typename... EpilogueArgs>
+void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
+ torch::Tensor const& bt_nzs,
+ torch::Tensor const& bt_meta,
+ EpilogueArgs&&... args) {
+ static_assert(std::is_same());
+ TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
+ TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
+ TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
+
+ using Cutlass3xGemmDefault =
+ typename sm90_config_default::Cutlass3xGemm;
+ using Cutlass3xGemmM64 =
+ typename sm90_fp8_config_M64::Cutlass3xGemm;
+ using Cutlass3xGemmM128 =
+ typename sm90_fp8_config_M128::Cutlass3xGemm;
+ using Cutlass3xGemmM256 =
+ typename sm90_fp8_config_M256::Cutlass3xGemm;
+ using Cutlass3xGemmM512 =
+ typename sm90_fp8_config_M512::Cutlass3xGemm;
+
+ using Cutlass3xGemm1 =
+ typename sm90_fp8_config_1::Cutlass3xGemm;
+ using Cutlass3xGemm2 =
+ typename sm90_fp8_config_2::Cutlass3xGemm;
+ using Cutlass3xGemm3 =
+ typename sm90_fp8_config_3::Cutlass3xGemm;
+ using Cutlass3xGemm4 =
+ typename sm90_fp8_config_4::Cutlass3xGemm;
+ using Cutlass3xGemm5 =
+ typename sm90_fp8_config_5::Cutlass3xGemm;
+ using Cutlass3xGemm6 =
+ typename sm90_fp8_config_6::Cutlass3xGemm;
+ using Cutlass3xGemm7 =
+ typename sm90_fp8_config_7::Cutlass3xGemm;
+ using Cutlass3xGemm8 =
+ typename sm90_fp8_config_8::Cutlass3xGemm;
+
+ uint32_t const n = bt_nzs.size(0);
+ uint32_t const m = a.size(0); // Batch size
+ uint32_t const mp2 =
+ std::max(static_cast(64), next_pow_2(m)); // next power of 2
+
+ if (mp2 <= 64) {
+ if (n == 28672) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 4096 || n == 6144) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ }
+ } else if (mp2 <= 128) {
+ if (n == 4096) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 28672) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 6144) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ }
+ } else if (mp2 <= 256) {
+ if (n == 4096) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 28672) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 6144) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ }
+ } else {
+ if (n == 6144 || n == 28672) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (n == 4096) {
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ }
+ }
+
+ // Otherwise the default heuristic
+ if (mp2 <= 64) {
+ // m in [1, 64]
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (mp2 <= 128) {
+ // m in (64, 128]
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else if (mp2 <= 256) {
+ // m in (128, 256]
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ } else {
+ // m in (256, inf)
+ return cutlass_sparse_gemm_caller(
+ out, a, bt_nzs, bt_meta, std::forward(args)...);
+ }
+}
+
+template