Skip to content

Commit

Permalink
[AMD][Quantization] Add TritonScaledMMLinearKernel since int8 is brok…
Browse files Browse the repository at this point in the history
…en for AMD (vllm-project#12282)

Signed-off-by: Randall Smith <[email protected]>
  • Loading branch information
rasmith authored Jan 23, 2025
1 parent aea9436 commit 68c4421
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
17 changes: 17 additions & 0 deletions tests/kernels/test_triton_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,23 @@ def get_8bit_types():
return types


# This test is to check regressions for int8 support on ROCm.
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="Should only run on ROCm")
def test_rocm_compressed_tensors_w8a8(vllm_runner, example_prompts, model_path,
max_tokens, num_logprobs):
dtype = "bfloat16"

with vllm_runner(model_path, dtype=dtype) as vllm_model:
vllm_model.generate_greedy_logprobs(example_prompts, max_tokens,
num_logprobs)


@pytest.mark.parametrize("M", [1, 33, 64, 512])
@pytest.mark.parametrize("N", [256, 971, 20486])
@pytest.mark.parametrize("K", [128, 496, 1024])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
CutlassScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel, ScaledMMLinearLayerConfig)
# from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
# TritonScaledMMLinear)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel)
from vllm.platforms import PlatformEnum, current_platform
Expand All @@ -15,9 +15,7 @@
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
# TODO(rob): Create TritonScaledMMLinear kernel. ROCM will
# incorrectly attempt to run AZP models if prompted to.
PlatformEnum.ROCM: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional, Tuple

import torch

from vllm.platforms import current_platform

from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):

@classmethod
def get_min_capability(cls) -> int:
return 75

@classmethod
def can_implement(
cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if current_platform.is_cpu():
return (
False,
"TritonScaledMMLinearKernel requires Triton which is not " +
"currently supported on CPU.")
if not c.input_symmetric:
return (False,
"TritonScaledMMLinearKernel only supports symmetric " +
"quantization.")
return True, None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)

def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return super().apply_weights(layer, x, bias)

0 comments on commit 68c4421

Please sign in to comment.