From 153b414e835ead40017b00b3049bfb657a7748fa Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 24 Jan 2025 19:22:39 +0800 Subject: [PATCH 01/16] minor: sync flashinfer and add turbomind as 3rdparty (#3105) --- .gitmodules | 3 +++ sgl-kernel/3rdparty/flashinfer | 2 +- sgl-kernel/3rdparty/turbomind | 1 + sgl-kernel/developer_guide.md | 1 + 4 files changed, 6 insertions(+), 1 deletion(-) create mode 160000 sgl-kernel/3rdparty/turbomind diff --git a/.gitmodules b/.gitmodules index ed7603bfd3c..97f3421449d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "sgl-kernel/3rdparty/flashinfer"] path = sgl-kernel/3rdparty/flashinfer url = https://github.com/flashinfer-ai/flashinfer.git +[submodule "sgl-kernel/3rdparty/turbomind"] + path = sgl-kernel/3rdparty/turbomind + url = https://github.com/InternLM/turbomind diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 93e1a2634e2..2d03ed7c01a 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 93e1a2634e22355b0856246b032b285ad1d1da6b +Subproject commit 2d03ed7c01aefd946c8a5781df9e59c0380116d4 diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind new file mode 160000 index 00000000000..0c9d0c724a9 --- /dev/null +++ b/sgl-kernel/3rdparty/turbomind @@ -0,0 +1 @@ +Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9 diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md index f41ce071e0b..91e93ff7508 100644 --- a/sgl-kernel/developer_guide.md +++ b/sgl-kernel/developer_guide.md @@ -19,6 +19,7 @@ Third-party libraries: - [CCCL](https://github.com/NVIDIA/cccl) - [CUTLASS](https://github.com/NVIDIA/cutlass) - [FlashInfer](https://github.com/flashinfer-ai/flashinfer) +- [TurboMind](https://github.com/InternLM/turbomind) ### Kernel Development From 685a5738a7b09faacc786e77f2a2ecfb5c9d6cea Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 24 Jan 2025 03:59:47 -0800 Subject: [PATCH 02/16] Allow local cutlass directory to be used in sgl-kernel build (#3037) --- sgl-kernel/setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index d60167435c4..cf3c6a56303 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,6 +39,8 @@ def _get_version(): cutlass = root / "3rdparty" / "cutlass" +cutlass_default = root / "3rdparty" / "cutlass" +cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" include_dirs = [ cutlass.resolve() / "include", From 4505a43614ba7826a192c122f749b99e170966b5 Mon Sep 17 00:00:00 2001 From: Adarsh Shirawalmath <114558126+adarshxs@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:30:20 +0530 Subject: [PATCH 03/16] [Docs] minor update for phi-3 and phi-4 (#3096) --- docs/references/supported_models.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 60551b2c1da..0a00ad0c8a1 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -28,6 +28,7 @@ - XVERSE / XVERSE MoE - SmolLM - GLM-4 +- Phi-3 / Phi-4 - Phi-3-Small - IBM Granite 3 From 04f0b4cbeff5f1d5e511a1ce5cc2f8cdfa0fc1fc Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 24 Jan 2025 20:10:35 +0800 Subject: [PATCH 04/16] minor: update sgl-kernel setup (#3107) --- sgl-kernel/setup.py | 26 +++--- .../src/sgl-kernel/csrc/fused_add_rms_norm.cu | 92 +++++++++++++++++++ 2 files changed, 103 
insertions(+), 15 deletions(-) create mode 100644 sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index cf3c6a56303..56c5b1bb56b 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -38,10 +38,10 @@ def _get_version(): return line.split("=")[1].strip().strip('"') -cutlass = root / "3rdparty" / "cutlass" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" +turbomind = root / "3rdparty" / "turbomind" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -49,6 +49,8 @@ def _get_version(): flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", + turbomind.resolve(), + turbomind.resolve() / "src", ] nvcc_flags = [ "-DNDEBUG", @@ -63,6 +65,11 @@ def _get_version(): "-use_fast_math", "-DFLASHINFER_ENABLE_F16", ] +nvcc_flags_fp8 = [ + "-DFLASHINFER_ENABLE_FP8", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", +] sources = [ "src/sgl-kernel/csrc/trt_reduce_internal.cu", @@ -73,6 +80,7 @@ def _get_version(): "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", + "src/sgl-kernel/csrc/fused_add_rms_norm.cu", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/group_gemm.cu", @@ -92,13 +100,7 @@ def _get_version(): nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") if sm_version >= 90: - nvcc_flags.extend( - [ - "-DFLASHINFER_ENABLE_FP8", - "-DFLASHINFER_ENABLE_FP8_E4M3", - "-DFLASHINFER_ENABLE_FP8_E5M2", - ] - ) + nvcc_flags.extend(nvcc_flags_fp8) if sm_version >= 80: nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") else: @@ -107,13 +109,7 @@ def _get_version(): nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") if enable_fp8: - nvcc_flags.extend( - [ - "-DFLASHINFER_ENABLE_FP8", - "-DFLASHINFER_ENABLE_FP8_E4M3", - "-DFLASHINFER_ENABLE_FP8_E5M2", - ] - ) + nvcc_flags.extend(nvcc_flags_fp8) if enable_bf16: nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu new file mode 100644 index 00000000000..73406158667 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm.cu @@ -0,0 +1,92 @@ +// Adapted from +// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu + +#include +#include + +#include + +using namespace turbomind; + +template<class T, int block_dim, int vec_size> +__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states, + const T* __restrict__ weights, const T* __restrict__ bias, int dims, int num, + float eps, float inv_dims) { + const int ti = blockIdx.x; + const int di = threadIdx.x * vec_size; + + if (ti >= num) { + return; + } + + residual += dims * ti; + hidden_states += dims * ti; + + Array<float, vec_size> accum{}; + + Array<T, vec_size> r_vec; + Array<T, vec_size> h_vec; + Array<T, vec_size> b_vec; + + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(r_vec, &residual[i]); + Load(h_vec, &hidden_states[i]); + + using namespace ops; + r_vec = r_vec + h_vec; + + if (bias) { + Ldg(b_vec, &bias[i]); + r_vec = r_vec + b_vec; + } + + Store(&residual[i], r_vec); + + Array<float, vec_size> tmp = cast<float>(r_vec); + + accum = accum + tmp * tmp; + } + + float sum{}; + PRAGMA_UNROLL + for (int i = 0; i < vec_size; ++i) { + sum += accum[i]; + } + + using BlockReduce = cub::BlockReduce<float, block_dim>; + __shared__ typename BlockReduce::TempStorage temp_storage; + + sum = BlockReduce{temp_storage}.Sum(sum); + + __shared__ float shared_sum; + + if (threadIdx.x == 0) { + shared_sum = rsqrtf(sum * inv_dims + eps); + } + + __syncthreads(); + + sum = shared_sum; + + Array<T, vec_size> w_vec; + for (int i = di; i < dims; i += block_dim * vec_size) { + Load(r_vec, &residual[i]); + Ldg(w_vec, &weights[i]); + PRAGMA_UNROLL + for (int c = 0; c < vec_size; ++c) { + r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c]; + } + Store(&hidden_states[i], r_vec); + } +} + +template<class T> +void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num, + float eps, cudaStream_t st) { + constexpr int vec_size = 16 / sizeof(T); + constexpr int threads = 512; + const int blocks = num; + + BiasResidualRMSNormKernel<T, threads, vec_size> + <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims); +}
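For orientation, the kernel above fuses the residual add and the RMSNorm into one pass: each block handles one token row, writes the updated residual back while accumulating the sum of squares in fp32, then normalizes and applies the weight. A rough PyTorch reference of the same computation (a sketch only; the function name and the `[num, dims]` shapes are illustrative, not part of the patch):

```python
import torch

def fused_add_rms_norm_ref(residual, hidden_states, weight, bias=None, eps=1e-6):
    # residual <- residual + hidden_states (+ bias), as the kernel stores back
    residual = residual + hidden_states
    if bias is not None:
        residual = residual + bias
    # sum of squares is accumulated in fp32, matching the kernel's float accumulator
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    # normalize in fp32, cast back to the storage dtype, then scale by the weight
    hidden_states = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype) * weight
    return residual, hidden_states
```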
From a22f60a313818678ba7455088833705be694c32f Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Fri, 24 Jan 2025 22:30:30 +0800 Subject: [PATCH 05/16] Add workflow for sgl-kernel cu118 release (#3109) --- .github/workflows/release-whl-kernel.yml | 59 ++++++++++++++++++++++++ sgl-kernel/build.sh | 8 +++- 2 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/release-whl-kernel.yml diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml new file mode 100644 index 00000000000..b49da1feb9c --- /dev/null +++ b/.github/workflows/release-whl-kernel.yml @@ -0,0 +1,59 @@ +name: Release SGLang Kernel Wheel (cu118) + +on: + workflow_dispatch: + inputs: + tag_name: + required: true + type: string + +jobs: + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12'] + cuda-version: ['11.8'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + release: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ inputs.tag_name }} + repository: sgl-project/whl + token: ${{ secrets.WHL_TOKEN }} + files: | + sgl-kernel/dist/* diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index c899224818e..1caa892bc84 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -4,6 +4,12 @@ PYTHON_VERSION=$1 CUDA_VERSION=$2 PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.} +if (( ${CUDA_VERSION%.*} < 12 )); then + ENABLE_SM90A=0 +else + ENABLE_SM90A=1 +fi + docker run --rm \ -v "$(pwd)":/sgl-kernel \ pytorch/manylinux-builder:cuda${CUDA_VERSION} \ @@ -13,7 +19,7 @@ docker run --rm \ export CUDA_VERSION=${CUDA_VERSION} && \ export SGL_KERNEL_ENABLE_BF16=1 && \ export SGL_KERNEL_ENABLE_FP8=1 && \ - export SGL_KERNEL_ENABLE_SM90A=1 && \ + export SGL_KERNEL_ENABLE_SM90A=${ENABLE_SM90A} && \ mkdir -p /usr/lib/x86_64-linux-gnu/ && \ ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ From 665e5e85f6d7a3a153d852cf11f73ba2f892fdff Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 25 Jan 2025 02:03:01 +0800 Subject: [PATCH 06/16] Add step to update sgl-kernel whl index (#3110) --- .github/workflows/release-whl-kernel.yml | 19 +++++++++++++++++++ scripts/update_kernel_whl_index.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+) create mode 100644 scripts/update_kernel_whl_index.py diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml index b49da1feb9c..1b2efaad77d 100644 --- a/.github/workflows/release-whl-kernel.yml +++ b/.github/workflows/release-whl-kernel.yml @@ -42,6 +42,8 @@ jobs: needs: build-wheels runs-on: ubuntu-latest steps: + - uses: actions/checkout@v4 + - name: Download artifacts uses: actions/download-artifact@v4 with: @@ -57,3 +59,20 @@ jobs: token: ${{ secrets.WHL_TOKEN }} files: | sgl-kernel/dist/* + + - name: Clone wheel index + run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Update wheel index + run: python3 scripts/update_kernel_whl_index.py + + - name: Push wheel index + run: | + cd sgl-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl index" + git push diff --git a/scripts/update_kernel_whl_index.py b/scripts/update_kernel_whl_index.py new file mode 100644 index 00000000000..bcd92ef64e9 --- /dev/null +++ b/scripts/update_kernel_whl_index.py @@ -0,0 +1,16 @@ +# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py + +import hashlib +import pathlib +import re + +for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")): + with open(path, "rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0] + index_dir = pathlib.Path(f"sgl-whl/cu118") + index_dir.mkdir(exist_ok=True) + base_url = "https://github.com/sgl-project/whl/releases/download" + full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" + with (index_dir / "index.html").open("a") as f: + f.write(f'<a href="{full_url}">{path.name}</a>\n')
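The generated index.html is a flat list of anchor tags whose URL fragment carries each wheel's checksum, in the style of a simple package index; installers can verify the fragment after download. A hedged sketch of that check, for validating an entry by hand (the function name and destination path are mine, not from the patch):

```python
import hashlib
import urllib.request

def check_index_entry(url_with_fragment: str, dest: str = "pkg.whl") -> None:
    # update_kernel_whl_index.py appends "#sha256=<hex>" to each release URL
    url, _, fragment = url_with_fragment.partition("#")
    expected = fragment.split("=", 1)[1]
    urllib.request.urlretrieve(url, dest)
    with open(dest, "rb") as f:
        actual = hashlib.sha256(f.read()).hexdigest()
    assert actual == expected, f"sha256 mismatch: {actual} != {expected}"
```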
From 5d9d15e70f7e73223a3d2baf3851b95a9d5356f0 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Sat, 25 Jan 2025 16:52:17 +0800 Subject: [PATCH 07/16] support fp32 in sampling_scaling_penalties kernel (#3121) --- .../csrc/sampling_scaling_penalties.cu | 3 +-- sgl-kernel/src/sgl-kernel/csrc/utils.h | 18 ++++++++++++++++++ .../tests/test_sampling_scaling_penalties.py | 10 +++++++--- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu index 2a9de4d9f71..18beb86445f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu @@ -1,7 +1,6 @@ #include #include #include -#include #include #include @@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] { uint32_t vec_size = 16 / sizeof(scalar_t); const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size); sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>( diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.h b/sgl-kernel/src/sgl-kernel/csrc/utils.h index 2fed2d60c03..ed802d4fdef 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.h +++ b/sgl-kernel/src/sgl-kernel/csrc/utils.h @@ -1,4 +1,5 @@ #pragma once +#include <sstream> #include #include @@ -44,3 +45,20 @@ inline int getSMVersion() { CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ switch (pytorch_dtype) { \ case at::ScalarType::Float: { \ using c_type = float; \ return __VA_ARGS__(); \ } \ _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ default: \ std::ostringstream oss; \ oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ TORCH_CHECK(false, oss.str()); \ return false; \ } \ }() diff --git a/sgl-kernel/tests/test_sampling_scaling_penalties.py b/sgl-kernel/tests/test_sampling_scaling_penalties.py index 6194c761710..a56eca866b2 100644 --- a/sgl-kernel/tests/test_sampling_scaling_penalties.py +++ b/sgl-kernel/tests/test_sampling_scaling_penalties.py @@ -2,10 +2,14 @@ import torch from sgl_kernel import sampling_scaling_penalties +batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65] +vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767] +dtypes = [torch.float32, torch.half, torch.bfloat16] -@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65]) -@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767]) -@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) + +@pytest.mark.parametrize("batch_size", batch_sizes) +@pytest.mark.parametrize("vocab_size", vocab_sizes) +@pytest.mark.parametrize("dtype", dtypes) def test_sampling_scaling_penalties(batch_size, vocab_size, dtype): device = torch.device("cuda") rtol = 1e-3
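For reference, the elementwise operation this kernel implements (now for fp32 as well as fp16/bf16) divides positive logits by the penalty and multiplies negative ones, which is the same eager formula that [PATCH 16/16] below later adopts via torch.compile. A minimal PyTorch sketch (function name is illustrative):

```python
import torch

def sampling_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    # Penalties > 1 push every logit toward zero, damping repeated tokens.
    return torch.where(logits > 0, logits / penalties, logits * penalties)
```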
From 98522149ff422d4700bf43dc6c944ee70cf2b516 Mon Sep 17 00:00:00 2001 From: yizhang2077 <1109276519@qq.com> Date: Sat, 25 Jan 2025 18:26:41 +0800 Subject: [PATCH 08/16] mirror fix for custom allreduce (#3124) --- sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index 006c3200dd1..8bdb5012543 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -160,7 +160,7 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag } template -static __global__ void oneShotAllReduceKernel(AllReduceParams params) { +static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start four blocks. 
// The message is partitioned into chunks as detailed below: // message From 14e754a868619b5099688d303667d09d2ef3724c Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 25 Jan 2025 20:43:02 +0800 Subject: [PATCH 09/16] chore: bump v0.0.2.post17 for sgl-kernel (#3125) --- sgl-kernel/3rdparty/flashinfer | 2 +- sgl-kernel/Makefile | 7 +++++-- sgl-kernel/pyproject.toml | 2 +- sgl-kernel/version.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 2d03ed7c01a..6e6f38d3534 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 2d03ed7c01aefd946c8a5781df9e59c0380116d4 +Subproject commit 6e6f38d3534994c34b2c6b09b5b45c8a7b92ffd2 diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index c7641bb5fee..1384f1bcd81 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -1,4 +1,4 @@ -.PHONY: tree ln submodule install build clean test format +.PHONY: tree ln submodule install build clean rebuild test format tree: @tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist" @@ -13,11 +13,14 @@ install: submodule @pip install -e . build: submodule - @export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps clean: @rm -rf build dist *.egg-info +rebuild: clean submodule build + @echo "Succeed to rebuild" + test: @find tests -name "test_*.py" | xargs -n 1 python3 diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 0032c369d94..582e67f4613 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post16" +version = "0.0.2.post17" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 5a127146bb5..ad3ff8af944 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.2.post16" +__version__ = "0.0.2.post17" From 3cab5f71eaff5baf4f1d033371d06e2262a396d0 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sat, 25 Jan 2025 21:37:48 +0800 Subject: [PATCH 10/16] speedup pr test for sgl-kernel (#3126) --- .github/workflows/pr-test-sgl-kernel.yml | 43 +++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index aea60969719..7b58052085b 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -30,20 +30,55 @@ jobs: clangFormatVersion: 16 style: file + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.10'] + cuda-version: ['12.4'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: 
sgl-kernel/dist/* + unit-test: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + needs: build-wheels runs-on: 1-gpu-runner steps: - uses: actions/checkout@v4 + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + - name: Install run: | pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 pip3 uninstall sgl-kernel -y || true - find . -name index.lock -delete - cd sgl-kernel - git submodule deinit --all --force && git submodule sync --recursive && git submodule update --init --force --recursive - pip3 install . + pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps pip3 list | grep sgl-kernel - name: Run test From 67ad4338e1016ff2aa31dbde7dd48432859eb6e5 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 25 Jan 2025 23:14:35 +0800 Subject: [PATCH 11/16] Update tag name for whl release (#3127) --- .github/workflows/release-whl-kernel.yml | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml index 1b2efaad77d..08a820c2aab 100644 --- a/.github/workflows/release-whl-kernel.yml +++ b/.github/workflows/release-whl-kernel.yml @@ -4,8 +4,12 @@ on: workflow_dispatch: inputs: tag_name: - required: true type: string + push: + branches: + - main + paths: + - sgl-kernel/version.py jobs: build-wheels: @@ -51,10 +55,20 @@ jobs: merge-multiple: true pattern: wheel-* + - name: Set tag name + id: set_tag_name + run: | + if [ -z "${{ inputs.tag_name }}" ]; then + TAG_NAME="v$(cat sgl-kernel/version.py | cut -d'"' -f2)" + echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT + else + echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT + fi + - name: Release uses: softprops/action-gh-release@v2 with: - tag_name: ${{ inputs.tag_name }} + tag_name: ${{ steps.set_tag_name.outputs.tag_name }} repository: sgl-project/whl token: ${{ secrets.WHL_TOKEN }} files: | From c23d5706f4148afc4e7a09d305e8508f4ee7bd0d Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 25 Jan 2025 23:57:09 +0800 Subject: [PATCH 12/16] Update whl index path (#3128) --- scripts/update_kernel_whl_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/update_kernel_whl_index.py b/scripts/update_kernel_whl_index.py index bcd92ef64e9..a42969641f5 100644 --- a/scripts/update_kernel_whl_index.py +++ b/scripts/update_kernel_whl_index.py @@ -8,7 +8,7 @@ with open(path, "rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0] - index_dir = pathlib.Path(f"sgl-whl/cu118") + index_dir = pathlib.Path(f"sgl-whl/cu118/sgl-kernel") index_dir.mkdir(exist_ok=True) base_url = "https://github.com/sgl-project/whl/releases/download" full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" From 896c07441ec12a3ff1b71e74905ba436f0f76501 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Sun, 26 Jan 2025 00:00:13 +0800 Subject: [PATCH 13/16] update installation doc for sgl-kernel (#3129) --- .github/workflows/pr-test-sgl-kernel.yml | 2 +- sgl-kernel/README.md | 16 +++++++++++++++- sgl-kernel/pyproject.toml | 2 +- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml index 7b58052085b..26b921eee33 100644 --- a/.github/workflows/pr-test-sgl-kernel.yml +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -31,7 +31,7 @@ 
jobs: style: file build-wheels: - if: github.repository == 'sgl-project/sglang' + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: ubuntu-latest strategy: matrix: diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 857cae366d8..0572f9758ab 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -1,5 +1,19 @@ # SGL Kernel -Kernel Library for SGLang +[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang [![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel) + +## Installation + +For CUDA 11.8: + +```bash +pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118 +``` + +For CUDA 12.1 or CUDA 12.4: + +```bash +pip3 install sgl-kernel +``` diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 582e67f4613..b23c302b564 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -17,7 +17,7 @@ classifiers = [ dependencies = [] [project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" +"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools] From 9286740eff9b735a005e14cf5dfae986c75e3533 Mon Sep 17 00:00:00 2001 From: yinfan98 <1106310035@qq.com> Date: Sun, 26 Jan 2025 02:55:08 +0800 Subject: [PATCH 14/16] feat: refactor sgl-kernel and use TORCH_LIBRARY instead of PYBIND11_MODULE for custom ops (#3130) Co-authored-by: yinfan.1024 Co-authored-by: yinfan98 <1106110035@qq.com> Co-authored-by: Yineng Zhang --- sgl-kernel/developer_guide.md | 11 +- sgl-kernel/setup.py | 11 +- .../sgl_kernels_ops.h} | 72 ++++------- .../{csrc => include}/trt_reduce_internal.cuh | 0 .../src/sgl-kernel/{csrc => include}/utils.h | 3 + sgl-kernel/src/sgl-kernel/ops/__init__.py | 93 ++++++-------- sgl-kernel/src/sgl-kernel/torch_extension.cc | 119 ++++++++++++++++++ 7 files changed, 198 insertions(+), 111 deletions(-) rename sgl-kernel/src/sgl-kernel/{csrc/sgl_kernel_ops.cu => include/sgl_kernels_ops.h} (65%) rename sgl-kernel/src/sgl-kernel/{csrc => include}/trt_reduce_internal.cuh (100%) rename sgl-kernel/src/sgl-kernel/{csrc => include}/utils.h (98%) create mode 100644 sgl-kernel/src/sgl-kernel/torch_extension.cc diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md index 91e93ff7508..26b68535c03 100644 --- a/sgl-kernel/developer_guide.md +++ b/sgl-kernel/developer_guide.md @@ -26,10 +26,11 @@ Third-party libraries: Steps to add a new kernel: 1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) -2. Expose interface in [csrc/sgl_kernel_ops.cu](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu) with pybind11 -3. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) -4. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) -5. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source +2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h) +3. 
Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) +4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source ### Build & Install @@ -37,8 +38,6 @@ Development build: ```bash make build -pip3 install dist/*whl --force-reinstall --no-deps -# Or use: make install (runs pip install -e .) ``` ### Testing & Benchmarking diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 56c5b1bb56b..95b040fe185 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -38,6 +38,7 @@ def _get_version(): return line.split("=")[1].strip().strip('"') +operator_namespace = "sgl_kernels" cutlass_default = root / "3rdparty" / "cutlass" cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" @@ -45,15 +46,19 @@ def _get_version(): include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", + root / "src" / "sgl-kernel" / "include", root / "src" / "sgl-kernel" / "csrc", flashinfer.resolve() / "include", flashinfer.resolve() / "include" / "gemm", flashinfer.resolve() / "csrc", + "cublas", + "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", ] nvcc_flags = [ "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", "-O3", "-Xcompiler", "-fPIC", @@ -72,13 +77,13 @@ def _get_version(): ] sources = [ + "src/sgl-kernel/torch_extension.cc", "src/sgl-kernel/csrc/trt_reduce_internal.cu", "src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/sampling_scaling_penalties.cu", "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/rotary_embedding.cu", "src/sgl-kernel/csrc/fused_add_rms_norm.cu", "3rdparty/flashinfer/csrc/activation.cu", @@ -125,7 +130,7 @@ def _get_version(): pass cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python", "cuda"] +libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"] extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] ext_modules = [ @@ -139,6 +144,7 @@ def _get_version(): }, libraries=libraries, extra_link_args=extra_link_args, + py_limited_api=True, ), ] @@ -149,6 +155,7 @@ def _get_version(): package_dir={"": "src"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) _update_wheel_platform_tag() diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h similarity index 65% rename from sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu rename to sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 876d62b7eb3..91e350895c2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -1,7 +1,25 @@ +#pragma once +#include +#include + #include #include "utils.h" +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) 
+ +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + // trt_reduce using fptr_t = int64_t; fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, @@ -67,9 +85,18 @@ void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at: int64_t cuda_stream); // top k renorm probs +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, unsigned int top_k_val, int64_t cuda_stream); +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. +// wrapper for binding +inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} + // top p renorm probs void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, double top_p_val, int64_t cuda_stream); @@ -84,48 +111,3 @@ void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_sample void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, std::optional maybe_top_p_arr, double top_p_val, bool deterministic, int64_t cuda_stream); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // trt_reduce - m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); - m.def("dispose", &dispose, "dispose custom allreduce meta"); - m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); - m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, "custom all reduce get graph ipc meta"); - m.def("register_graph_buffers", ®ister_graph_buffers, "custom all reduce register graph buffers"); - // moe_align_block_size - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); - // sampling_scaling_penalties - m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)"); - // int8_scaled_mm - m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); - // lightning_attention_decode - m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)"); - // rotary embedding - m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)"); - // rms norm - m.def("rmsnorm", &rmsnorm, "RMSNorm (CUDA)"); - // fused rms norm - m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused Add RMSNorm (CUDA)"); - // gemma rms norm - m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma RMSNorm (CUDA)"); - // fused gemma rms norm - m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm, "Gemma Fused Add RMSNorm (CUDA)"); - // silu and mul - m.def("silu_and_mul", &silu_and_mul, "Silu and Mul (CUDA)"); - // gelu tanh and mul - m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)"); - // gelu and mul - m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)"); - // bmm fp8 - m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)"); - // min p sampling from probs - m.def("min_p_sampling_from_probs", &min_p_sampling_from_probs, "Min P Sampling From Probs 
(CUDA)"); - // top k renorm probs - m.def("top_k_renorm_probs", &top_k_renorm_probs, "Top K Renorm Probs (CUDA)"); - // top p renorm probs - m.def("top_p_renorm_probs", &top_p_renorm_probs, "Top P Renorm Probs (CUDA)"); - // top k top p sampling from probs - m.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs, "Top K Top P Sampling From Probs (CUDA)"); - // top p sampling from probs - m.def("top_p_sampling_from_probs", &top_p_sampling_from_probs, "Top P Sampling From Probs (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh similarity index 100% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh rename to sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h similarity index 98% rename from sgl-kernel/src/sgl-kernel/csrc/utils.h rename to sgl-kernel/src/sgl-kernel/include/utils.h index ed802d4fdef..1cca35d5cd7 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.h +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,9 +1,12 @@ #pragma once +#include #include #include #include +#include "sgl_kernels_ops.h" + struct cuda_error : public std::runtime_error { /** * @brief Constructs a `cuda_error` object with the given `message`. diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index cd69eb3c249..3a21ced875a 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,41 +1,8 @@ +import os from typing import Optional, Tuple, Union +import sgl_kernel.ops._kernels import torch -from sgl_kernel.ops._kernels import all_reduce as _all_reduce -from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8 -from sgl_kernel.ops._kernels import dispose as _dispose -from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm -from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul -from sgl_kernel.ops._kernels import gelu_tanh_and_mul as _gelu_tanh_and_mul -from sgl_kernel.ops._kernels import gemma_fused_add_rmsnorm as _gemma_fused_add_rmsnorm -from sgl_kernel.ops._kernels import gemma_rmsnorm as _gemma_rmsnorm -from sgl_kernel.ops._kernels import ( - get_graph_buffer_ipc_meta as _get_graph_buffer_ipc_meta, -) -from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar -from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm -from sgl_kernel.ops._kernels import ( - lightning_attention_decode as _lightning_attention_decode, -) -from sgl_kernel.ops._kernels import ( - min_p_sampling_from_probs as _min_p_sampling_from_probs, -) -from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size -from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers -from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm -from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding -from sgl_kernel.ops._kernels import ( - sampling_scaling_penalties as _sampling_scaling_penalties, -) -from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul -from sgl_kernel.ops._kernels import top_k_renorm_probs as _top_k_renorm_probs -from sgl_kernel.ops._kernels import ( - top_k_top_p_sampling_from_probs as _top_k_top_p_sampling_from_probs, -) -from sgl_kernel.ops._kernels import top_p_renorm_probs as _top_p_renorm_probs -from sgl_kernel.ops._kernels import ( - top_p_sampling_from_probs as 
_top_p_sampling_from_probs, -) from sgl_kernel.ops.utils import ( _get_cache_buf, _get_cuda_stream, @@ -46,25 +13,25 @@ def init_custom_reduce( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ): - return _init_custom_ar( + return torch.ops.sgl_kernels.init_custom_ar( rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out ) def custom_dispose(fa): - _dispose(fa) + torch.ops.sgl_kernels.dispose(fa) def custom_reduce(fa, inp, out): - _all_reduce(fa, inp, out) + torch.ops.sgl_kernels.all_reduce(fa, inp, out) def get_graph_buffer_ipc_meta(fa): - return _get_graph_buffer_ipc_meta(fa) + return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) def register_graph_buffers(fa, handles, offsets): - _register_graph_buffers(fa, handles, offsets) + torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) def moe_align_block_size( @@ -77,7 +44,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - _moe_align_block_size( + torch.ops.sgl_kernels.moe_align_block_size( topk_ids, num_experts, block_size, @@ -90,11 +57,11 @@ def moe_align_block_size( def sampling_scaling_penalties(logits, scaling_penalties): - return _sampling_scaling_penalties(logits, scaling_penalties) + return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties) def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _int8_scaled_mm( + return torch.ops.sgl_kernels.int8_scaled_mm( mat_a, mat_b, scales_a, @@ -105,11 +72,15 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): - _lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + torch.ops.sgl_kernels.lightning_attention_decode( + q, k, v, past_kv, slope, output, new_kv + ) def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): - return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox) + return torch.ops.sgl_kernels.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer @@ -123,7 +94,7 @@ def rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) return out @@ -131,7 +102,9 @@ def fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) def gemma_rmsnorm( @@ -143,7 +116,9 @@ def gemma_rmsnorm( with input.device as device: if out is None: out = torch.empty_like(input) - _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gemma_rmsnorm( + out, input, weight, eps, _get_cuda_stream(device) + ) return out @@ -151,7 +126,9 @@ def gemma_fused_add_rmsnorm( input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 ) -> None: with input.device as device: - _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) def 
_check_shape(input: torch.Tensor, output: torch.Tensor) -> None: @@ -176,7 +153,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: dtype=input.dtype, ) with input.device as device: - _silu_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device)) return out @@ -192,7 +169,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te dtype=input.dtype, ) with input.device as device: - _gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) return out @@ -208,7 +185,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: dtype=input.dtype, ) with input.device as device: - _gelu_and_mul(out, input, _get_cuda_stream(device)) + torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device)) return out @@ -222,7 +199,7 @@ def _bmm_fp8_internal( ) -> None: with A.device as device: cublas_handle = torch.cuda.current_blas_handle() - _bmm_fp8( + torch.ops.sgl_kernels.bmm_fp8( A, B, D, @@ -262,7 +239,7 @@ def _top_k_renorm_probs_internal( probs = probs.float() maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None renorm_probs = torch.empty_like(probs) - _top_k_renorm_probs( + torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( probs, renorm_probs, maybe_top_k_arr, @@ -293,7 +270,7 @@ def _top_p_renorm_probs_internal( maybe_top_p_arr.float() if maybe_top_p_arr is not None else None ) renorm_probs = torch.empty_like(probs) - _top_p_renorm_probs( + torch.ops.sgl_kernels.top_p_renorm_probs( probs, renorm_probs, maybe_top_p_arr, @@ -328,7 +305,7 @@ def _top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - _top_p_sampling_from_probs( + torch.ops.sgl_kernels.top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -374,7 +351,7 @@ def _top_k_top_p_sampling_from_probs_internal( ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) success = torch.empty(probs.size(0), dtype=torch.bool, device=device) - _top_k_top_p_sampling_from_probs( + torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( probs, uniform_samples, samples, @@ -432,7 +409,7 @@ def _min_p_sampling_from_probs_internal( maybe_min_p_arr.float() if maybe_min_p_arr is not None else None ) samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) - _min_p_sampling_from_probs( + torch.ops.sgl_kernels.min_p_sampling_from_probs( probs, uniform_samples, samples, diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc new file mode 100644 index 00000000000..f8a061c15d5 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -0,0 +1,119 @@ + +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // trt_reduce + m.def( + "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " + "barrier_in, int[] barrier_out) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("dispose", &dispose); + + m.def("all_reduce(int fa, Tensor inp, Tensor! 
out) -> ()"); + m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta); + + m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); + + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // sampling_scaling_penalties + m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor"); + m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties); + + // int8_scaled_mm + m.def( + "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + + // lightning_attention_decode + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + + // rotary embedding + m.def( + "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool " + "is_neox) -> ()"); + m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); + + // rms norm + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + // fused rms norm + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm); + + // gemma rms norm + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + // fused gemma rms norm + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + // silu and mul + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // gelu tanh and mul + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // gelu and mul + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // bmm fp8 + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); + + // min p sampling from probs + m.def( + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + // top k renorm probs + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? 
maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + // top p renorm probs + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + // top k top p sampling from probs + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + // top p sampling from probs + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); +} + +REGISTER_EXTENSION(_kernels) From da6f8081f6bc59f56ac773ded42e16b4043a93a5 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 25 Jan 2025 17:43:39 -0800 Subject: [PATCH 15/16] Fix CI tests (#3132) --- .github/workflows/pr-test.yml | 2 ++ test/srt/test_bench_serving.py | 12 ++++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index c5eeeee3c14..998a12e75d8 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -43,6 +43,8 @@ jobs: - name: Run test timeout-minutes: 10 + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | cd test/lang python3 run_suite.py --suite per-commit diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index b55260f71a6..8233438fcaf 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -49,7 +49,7 @@ def test_offline_throughput_non_stream_small_batch_size(self): ) # There is a regression with torch 2.5 # This number was 950 for torch 2.4 - self.assertGreater(res["output_throughput"], 850) + self.assertGreater(res["output_throughput"], 1000) def test_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -114,7 +114,7 @@ def test_offline_throughput_default_fp8(self): f"### test_offline_throughput_default_fp8\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 3850) + self.assertGreater(res["output_throughput"], 3900) def test_online_latency_default(self): res = run_bench_serving( @@ -129,7 +129,7 @@ def test_online_latency_default(self): f"### test_online_latency_default\n" f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 12000) + self.assertLess(res["median_e2e_latency_ms"], 11000) self.assertLess(res["median_ttft_ms"], 86) self.assertLess(res["median_itl_ms"], 10) @@ -161,7 +161,7 @@ def test_online_latency_eagle(self): f"### test_online_latency_eagle\n" f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 10000) + self.assertLess(res["median_e2e_latency_ms"], 450) def test_moe_offline_throughput_default(self): res = run_bench_serving( @@ -176,7 +176,7 @@ def test_moe_offline_throughput_default(self): f"### test_moe_offline_throughput_default\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - 
self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) def test_moe_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -191,7 +191,7 @@ def test_moe_offline_throughput_without_radix_cache(self): f"### test_moe_offline_throughput_without_radix_cache\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) if __name__ == "__main__": From 27acf63bbd37eeb82231eca611a9d2947dc74ac6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 25 Jan 2025 18:27:33 -0800 Subject: [PATCH 16/16] Use torch.compile for scaling penalty (#3133) --- .../benchmark_deepseekv3_moe_align_blocks.py | 1 - .../penalizers/repetition_penalty.py | 24 ++++++++----------- .../srt/sampling/sampling_batch_info.py | 18 ++++---------- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index d00f4985ad2..e2c4d8d3506 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -1,6 +1,5 @@ import argparse import itertools -import time import torch import triton diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py index fcd5ff71c23..0f714c54806 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -3,11 +3,16 @@ import torch from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs -from sglang.srt.utils import is_cuda_available +from sglang.srt.utils import get_compiler_backend -is_cuda = is_cuda_available() -if is_cuda: - from sgl_kernel import sampling_scaling_penalties + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits > 0, + logits / scaling_penalties, + logits * scaling_penalties, + ) class BatchedRepetitionPenalizer(_BatchedPenalizer): @@ -61,16 +66,7 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: - if is_cuda: - return sampling_scaling_penalties( - logits, self.cumulated_repetition_penalties - ) - else: - return torch.where( - logits > 0, - logits / self.cumulated_repetition_penalties, - logits * self.cumulated_repetition_penalties, - ) + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index a27ff1ad2a3..9521a34f4f6 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -7,14 +7,11 @@ import torch -from sglang.srt.utils import is_cuda_available - -is_cuda = is_cuda_available() -if is_cuda: - from sgl_kernel import sampling_scaling_penalties - import sglang.srt.sampling.penaltylib as penaltylib from 
sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + apply_scaling_penalties, +) logger = logging.getLogger(__name__) @@ -386,14 +383,7 @@ def apply_logits_bias(self, logits: torch.Tensor): # repetition if self.scaling_penalties is not None: - if is_cuda: - logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties) - else: - logits[:] = torch.where( - logits > 0, - logits / self.scaling_penalties, - logits * self.scaling_penalties, - ) + apply_scaling_penalties(logits, self.scaling_penalties) # Apply regex vocab_mask if self.vocab_mask is not None:
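The compiled replacement in this last patch updates logits in place (logits[:] = torch.where(...)). A hedged self-check of the idea, using an out-of-place variant so eager and compiled results can be compared directly (names, shapes, and the default inductor backend here are illustrative; the patch itself selects a backend via get_compiler_backend()):

```python
import torch

def scaling_penalties_eager(logits, penalties):
    # Same formula as apply_scaling_penalties, but returning a new tensor.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

# dynamic=True mirrors the patch, so varying batch sizes avoid recompiles.
compiled = torch.compile(scaling_penalties_eager, dynamic=True)

logits = torch.randn(4, 32000)
penalties = torch.full((4, 1), 1.2)
torch.testing.assert_close(compiled(logits, penalties), scaling_penalties_eager(logits, penalties))
```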