diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py
index 55ab0451d828f..ba92557a7f6b7 100644
--- a/tests/kernels/test_awq_triton.py
+++ b/tests/kernels/test_awq_triton.py
@@ -7,6 +7,7 @@
 import pytest
 import torch
 
+from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize
 from vllm.model_executor.layers.quantization.awq_triton import (
     awq_dequantize_triton, awq_gemm_triton)
 
@@ -76,7 +77,7 @@ def awq_gemm_torch(input: torch.Tensor, qweight: torch.Tensor,
     print(f"awq_gemm_torch:input_rows = {input_rows} input_cols = {input_cols}"
           f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}"
           f" scales_rows = {scales_rows} scales_cols = {scales_cols}")
-    weights, zeros = awq_dequantize_torch(qweight, scales, qzeros)
+    weights = torch_awq_dequantize(qweight, scales, qzeros)
     return torch.matmul(input, weights)
 
 
@@ -123,7 +124,7 @@ def test_dequantize(qweight_rows, qweight_cols):
     print("Any infs in triton result? -->"
          f"{torch.any(torch.isinf(iweights_triton))}")
 
-    iweights_torch, _ = awq_dequantize_torch(qweight, scales, zeros)
+    iweights_torch = torch_awq_dequantize(qweight, scales, zeros)
     print(f"Torch result:iweights_torch = {iweights_torch}")
 
     diff = iweights_torch - iweights_triton
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a5d4a9386e1f4..9de7e666b84e6 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -131,12 +131,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                    zeros: torch.Tensor, split_k_iters: int, thx: int,
                    thy: int) -> torch.Tensor:
-    print(f"awq_dequantize:qweight.shape = {qweight.shape}"
-          f"scales = {scales.shape},"
-          f"zeros = {zeros.shape},"
-          f"split_k_iters = {split_k_iters},"
-          f"thx = {thx}"
-          f"thy = {thy}")
+    # print(f"awq_dequantize:qweight.shape = {qweight.shape}"
+    #       f"scales = {scales.shape},"
+    #       f"zeros = {zeros.shape},"
+    #       f"split_k_iters = {split_k_iters},"
+    #       f"thx = {thx}"
+    #       f"thy = {thy}")
     if is_hip() and envs.VLLM_USE_TRITON_AWQ:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_dequantize_triton)
@@ -153,12 +153,12 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
 
 def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
              scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
-    if input.shape[0] > 1:
-        print(f"awq_gemm:input.shape = {input.shape},"
-              f"qweight = {qweight.shape},"
-              f"qzeros = {qzeros.shape},"
-              f"scales.shape = {scales.shape},"
-              f"split_k_iters = {split_k_iters}")
+    # if input.shape[0] > 1:
+    #     print(f"awq_gemm:input.shape = {input.shape},"
+    #           f"qweight = {qweight.shape},"
+    #           f"qzeros = {qzeros.shape},"
+    #           f"scales.shape = {scales.shape},"
+    #           f"split_k_iters = {split_k_iters}")
     if is_hip() and envs.VLLM_USE_TRITON_AWQ:
         from vllm.model_executor.layers.quantization.awq_triton import (
             awq_gemm_triton)
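
The shape-dump prints in vllm/_custom_ops.py above are only commented out, not removed. If the diagnostics are still occasionally useful, one possible alternative is to route them through debug logging instead; a sketch only, assuming vLLM's existing vllm.logger.init_logger helper (the _log_awq_shapes helper name and message format below are illustrative, not part of the patch):

from vllm.logger import init_logger

logger = init_logger(__name__)


def _log_awq_shapes(qweight, scales, zeros, split_k_iters, thx, thy):
    # Emitted only when debug logging is enabled for this module.
    logger.debug(
        "awq_dequantize: qweight=%s scales=%s zeros=%s "
        "split_k_iters=%d thx=%d thy=%d", qweight.shape, scales.shape,
        zeros.shape, split_k_iters, thx, thy)
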
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 894c9e9dc6554..8a2f7ac06d0e1 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -283,7 +283,7 @@ def __init__(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {supported_head_sizes}.")
 
-        self.use_naive_attn = False
+        self.use_naive_attn = envs.VLLM_USE_SDPA_ATTENTION  # Default False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
         if self.use_triton_flash_attn:
@@ -306,7 +306,7 @@ def __init__(
 
         if self.use_naive_attn:
             self.attn_func = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
+            logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
diff --git a/vllm/envs.py b/vllm/envs.py
index abcb87fcf8fff..ac3a665a7420f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -8,6 +8,9 @@
     VLLM_INSTANCE_ID: Optional[str] = None
     VLLM_NCCL_SO_PATH: Optional[str] = None
     LD_LIBRARY_PATH: Optional[str] = None
+    VLLM_ROCM_PREFER_TORCH: bool = False
+    VLLM_ROCM_PREFER_TRITON: bool = True
+    VLLM_USE_SDPA_ATTENTION: bool = False
     VLLM_USE_TRITON_FLASH_ATTN: bool = True
     VLLM_USE_ROCM_SKINNY_GEMM: bool = True
     VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
@@ -136,6 +139,21 @@
     "LD_LIBRARY_PATH":
     lambda: os.environ.get("LD_LIBRARY_PATH", None),
 
+    # flag to tell vllm to prefer torch on ROCm
+    "VLLM_ROCM_PREFER_TORCH":
+    lambda: (os.environ.get("VLLM_ROCM_PREFER_TORCH", "False").lower() in
+             ("true", "1")),
+
+    # flag to tell vllm to prefer triton on ROCm
+    "VLLM_ROCM_PREFER_TRITON":
+    lambda: (os.environ.get("VLLM_ROCM_PREFER_TRITON", "True").lower() in
+             ("true", "1")),
+
+    # flag to control if vllm should use naive scaled dot-product attention
+    "VLLM_USE_SDPA_ATTENTION":
+    lambda: (os.environ.get("VLLM_USE_SDPA_ATTENTION", "False").lower() in
+             ("true", "1")),
+
    # flag to control if vllm should use triton flash attention
     "VLLM_USE_TRITON_FLASH_ATTN":
     lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
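
For reference, the new flags follow the same truthy parsing as the existing ones: the values "true" and "1" (case-insensitive) enable them, anything else disables them. A minimal sketch of selecting the pure-torch AWQ dequantize path added in the awq.py change below, assuming a vLLM install with this patch applied; only the variable names come from the diff, the rest is illustrative:

import os

# Prefer the torch dequantize fallback over the Triton path on ROCm.
os.environ["VLLM_ROCM_PREFER_TORCH"] = "1"
os.environ["VLLM_ROCM_PREFER_TRITON"] = "0"

import vllm.envs as envs  # the flags surface as attributes of vllm.envs

assert envs.VLLM_ROCM_PREFER_TORCH
assert not envs.VLLM_ROCM_PREFER_TRITON
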
diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index f4fc7ce020e95..e47b8542f269b 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -3,6 +3,7 @@
 import torch
 from torch.nn.parameter import Parameter
 
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
@@ -164,8 +165,15 @@ def apply(self,
         # num_tokens >= threshold
         FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
 
-        if FP16_MATMUL_HEURISTIC_CONDITION:
-            out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+        prefer_torch = envs.VLLM_ROCM_PREFER_TORCH
+        prefer_triton = envs.VLLM_ROCM_PREFER_TRITON
+
+        if (FP16_MATMUL_HEURISTIC_CONDITION
+                or (prefer_torch and not prefer_triton)):
+            if prefer_triton:
+                out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
+            else:
+                out = torch_awq_dequantize(qweight, scales, qzeros)
             out = torch.matmul(reshaped_x, out)
         else:
             out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
@@ -173,3 +181,44 @@ def apply(self,
                                pack_factor)
         if bias is not None:
             out.add_(bias)
         return out.reshape(out_shape)
+
+
+def torch_awq_dequantize(qweights: torch.Tensor, scales: torch.Tensor,
+                         qzeros: torch.Tensor) -> torch.Tensor:
+    reverse_awq_func_desc = torch.tensor([0, 16, 4, 20, 8, 24, 12, 28],
+                                         dtype=torch.int32,
+                                         device=qweights.device)
+    if qzeros is None:
+        qzeros = torch.zeros_like(qweights)
+
+    while qweights.dim() < 2:
+        qweights = torch.unsqueeze(qweights, 0)
+    while qzeros.dim() < 2:
+        qzeros = torch.unsqueeze(qzeros, 0)
+    while scales.dim() < 2:
+        scales = torch.unsqueeze(scales, 0)
+
+    rows = qweights.size(-2)
+    group_size_zeros = rows // qzeros.size(-2)
+    group_size_scales = rows // scales.size(-2)
+
+    qweights_shape = list(qweights.shape)
+    qweights_shape[-1] *= 8
+    qzeros_shape = list(qzeros.shape)
+    qzeros_shape[-1] *= 8
+
+    qweights = torch.unsqueeze(qweights, -1)
+    qzeros = torch.unsqueeze(qzeros, -1)
+
+    unpacked_weights = torch.bitwise_right_shift(qweights,
+                                                 reverse_awq_func_desc)
+    unpacked_weights = torch.bitwise_and(unpacked_weights, 0xf)
+    unpacked_weights = unpacked_weights.to(torch.int8).view(qweights_shape)
+
+    unpacked_zeros = torch.bitwise_right_shift(qzeros, reverse_awq_func_desc)
+    unpacked_zeros = torch.bitwise_and(unpacked_zeros, 0xf)
+    unpacked_zeros = unpacked_zeros.to(torch.int8).view(qzeros_shape)
+    unpacked_zeros = unpacked_zeros.repeat_interleave(group_size_zeros, dim=-2)
+
+    functional_scales = scales.repeat_interleave(group_size_scales, dim=-2)
+    return (unpacked_weights - unpacked_zeros) * functional_scales
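
A quick shape-level sanity check of torch_awq_dequantize, in the spirit of the updated test_awq_triton.py but runnable on CPU: each int32 column of qweight packs eight 4-bit values, and qzeros/scales hold one row per quantization group. The concrete sizes below (256 rows, 16 packed columns, group size 128) are illustrative, not taken from the diff, and the import assumes a checkout with this patch applied:

import torch

from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize

rows, packed_cols, group_size = 256, 16, 128
qweight = torch.randint(0, 2**31 - 1, (rows, packed_cols), dtype=torch.int32)
qzeros = torch.randint(0, 2**31 - 1, (rows // group_size, packed_cols),
                       dtype=torch.int32)
scales = torch.rand(rows // group_size, packed_cols * 8).half()

weights = torch_awq_dequantize(qweight, scales, qzeros)
# Each packed int32 column expands to 8 unpacked columns.
assert weights.shape == (rows, packed_cols * 8)
assert weights.dtype == scales.dtype
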