[hegeman/AWQ] Torch Int-4 AWQ Dequantization and Configuration Options
hegemanjw4amd authored Aug 21, 2024
1 parent 280db50 commit 4e9830e
Showing 5 changed files with 86 additions and 18 deletions.
5 changes: 3 additions & 2 deletions tests/kernels/test_awq_triton.py
@@ -7,6 +7,7 @@
import pytest
import torch

from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton, awq_gemm_triton)

@@ -76,7 +77,7 @@ def awq_gemm_torch(input: torch.Tensor, qweight: torch.Tensor,
print(f"awq_gemm_torch:input_rows = {input_rows} input_cols = {input_cols}"
f" qweight_rows = {qweight_rows} qweight_cols = {qweight_cols}"
f" scales_rows = {scales_rows} scales_cols = {scales_cols}")
weights, zeros = awq_dequantize_torch(qweight, scales, qzeros)
weights = torch_awq_dequantize(qweight, scales, qzeros)
return torch.matmul(input, weights)


@@ -123,7 +124,7 @@ def test_dequantize(qweight_rows, qweight_cols):
print("Any infs in triton result? -->"
f"{torch.any(torch.isinf(iweights_triton))}")

iweights_torch, _ = awq_dequantize_torch(qweight, scales, zeros)
iweights_torch = torch_awq_dequantize(qweight, scales, zeros)
print(f"Torch result:iweights_torch = {iweights_torch}")

diff = iweights_torch - iweights_triton
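
For reference, the comparison that the updated test performs can be reproduced standalone. The sketch below is illustrative only: it assumes a ROCm/CUDA device, a group size of 128, and that awq_dequantize_triton accepts (qweight, scales, zeros) positionally as in the test above.

    # Standalone triton-vs-torch dequantize comparison (illustrative sketch).
    # Assumptions: a GPU is available; group_size=128; awq_dequantize_triton
    # takes (qweight, scales, zeros) positionally.
    import torch
    from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize
    from vllm.model_executor.layers.quantization.awq_triton import awq_dequantize_triton

    rows, packed_cols, group_size = 256, 128, 128   # each int32 packs eight int4 weights
    device = "cuda"
    qweight = torch.randint(0, 2**31 - 1, (rows, packed_cols),
                            dtype=torch.int32, device=device)
    qzeros = torch.randint(0, 2**31 - 1, (rows // group_size, packed_cols),
                           dtype=torch.int32, device=device)
    scales = torch.rand(rows // group_size, packed_cols * 8,
                        dtype=torch.float16, device=device)

    ref = torch_awq_dequantize(qweight, scales, qzeros)   # (rows, packed_cols * 8)
    out = awq_dequantize_triton(qweight, scales, qzeros)
    print("max abs diff:", (ref.to(out.dtype) - out).abs().max().item())
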
24 changes: 12 additions & 12 deletions vllm/_custom_ops.py
@@ -131,12 +131,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
print(f"awq_dequantize:qweight.shape = {qweight.shape}"
f"scales = {scales.shape},"
f"zeros = {zeros.shape},"
f"split_k_iters = {split_k_iters},"
f"thx = {thx}"
f"thy = {thy}")
# print(f"awq_dequantize:qweight.shape = {qweight.shape}"
# f"scales = {scales.shape},"
# f"zeros = {zeros.shape},"
# f"split_k_iters = {split_k_iters},"
# f"thx = {thx}"
# f"thy = {thy}")
if is_hip() and envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_dequantize_triton)
@@ -153,12 +153,12 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,

def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
if input.shape[0] > 1:
print(f"awq_gemm:input.shape = {input.shape},"
f"qweight = {qweight.shape},"
f"qzeros = {qzeros.shape},"
f"scales.shape = {scales.shape},"
f"split_k_iters = {split_k_iters}")
# if input.shape[0] > 1:
# print(f"awq_gemm:input.shape = {input.shape},"
# f"qweight = {qweight.shape},"
# f"qzeros = {qzeros.shape},"
# f"scales.shape = {scales.shape},"
# f"split_k_iters = {split_k_iters}")
if is_hip() and envs.VLLM_USE_TRITON_AWQ:
from vllm.model_executor.layers.quantization.awq_triton import (
awq_gemm_triton)
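
Both hunks above only change debug prints; the routing itself is unchanged: the Triton AWQ kernels are taken only on ROCm (HIP) and only when VLLM_USE_TRITON_AWQ is enabled, otherwise the pre-existing custom op is used. A trivial, illustrative restatement of that guard (the helper is a stand-in, not vLLM code):

    # Stand-in for the guard `is_hip() and envs.VLLM_USE_TRITON_AWQ` used above.
    def uses_triton_awq(on_hip: bool, vllm_use_triton_awq: bool) -> bool:
        return on_hip and vllm_use_triton_awq

    assert uses_triton_awq(True, True) is True     # ROCm with the flag set -> Triton path
    assert uses_triton_awq(True, False) is False   # ROCm without the flag -> default op
    assert uses_triton_awq(False, True) is False   # CUDA -> the flag has no effect
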
4 changes: 2 additions & 2 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -283,7 +283,7 @@ def __init__(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")

self.use_naive_attn = False
self.use_naive_attn = envs.VLLM_USE_SDPA_ATTENTION # Default False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
if self.use_triton_flash_attn:
@@ -306,7 +306,7 @@ def __init__(

if self.use_naive_attn:
self.attn_func = _naive_attention
logger.debug("Using naive attention in ROCmBackend")
logger.debug("Using naive (SDPA) attention in ROCmBackend")

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
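
As an illustration (not part of the commit), the new naive/SDPA attention path could be opted into as below. Assumptions: the environment variables must be set before vllm.envs is imported, the Triton flash-attention path is disabled alongside it (the interaction between the two flags is not fully visible in the hunk above), and the model name is just a placeholder.

    # Hypothetical opt-in to the naive (SDPA) attention path on ROCm.
    import os
    os.environ["VLLM_USE_SDPA_ATTENTION"] = "1"      # new flag from this commit
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"   # conservative assumption: avoid the Triton path

    from vllm import LLM  # import after setting the flags

    llm = LLM(model="TheBloke/Llama-2-7B-AWQ", quantization="awq")  # placeholder model
    print(llm.generate("Hello"))
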
18 changes: 18 additions & 0 deletions vllm/envs.py
@@ -8,6 +8,9 @@
VLLM_INSTANCE_ID: Optional[str] = None
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_ROCM_PREFER_TORCH: bool = False
VLLM_ROCM_PREFER_TRITON: bool = True
VLLM_USE_SDPA_ATTENTION: bool = False
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
@@ -136,6 +139,21 @@
"LD_LIBRARY_PATH":
lambda: os.environ.get("LD_LIBRARY_PATH", None),

# flag to tell vllm to prefer torch on ROCm
"VLLM_ROCM_PREFER_TORCH":
lambda: (os.environ.get("VLLM_ROCM_PREFER_TORCH", "False").lower() in
("true", "1")),

# flag to tell vllm to prefer triton on ROCm
"VLLM_ROCM_PREFER_TRITON":
lambda: (os.environ.get("VLLM_ROCM_PREFER_TRITON", "True").lower() in
("true", "1")),

# flag to control if vllm should use naive scaled dot-product attention
"VLLM_USE_SDPA_ATTENTION":
lambda: (os.environ.get("VLLM_USE_SDPA_ATTENTION", "False").lower() in
("true", "1")),

# flag to control if vllm should use triton flash attention
"VLLM_USE_TRITON_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
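
All of the new flags are parsed with the same convention as the existing ones: only "true" or "1" (case-insensitive) enables a flag, and any other value falls back to false. The helper below is a stand-in for the lambdas above (not vLLM code) and just illustrates that convention:

    # Stand-in mirroring the "true"/"1" parsing convention of the lambdas above.
    import os

    def parse_flag(name: str, default: str) -> bool:
        return os.environ.get(name, default).lower() in ("true", "1")

    os.environ["VLLM_ROCM_PREFER_TORCH"] = "TRUE"   # accepted (case-insensitive)
    os.environ["VLLM_ROCM_PREFER_TRITON"] = "yes"   # not accepted -> False
    print(parse_flag("VLLM_ROCM_PREFER_TORCH", "False"))   # True
    print(parse_flag("VLLM_ROCM_PREFER_TRITON", "True"))   # False
    print(parse_flag("VLLM_USE_SDPA_ATTENTION", "False"))  # False (unset -> default)
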
53 changes: 51 additions & 2 deletions vllm/model_executor/layers/quantization/awq.py
@@ -3,6 +3,7 @@
import torch
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
@@ -164,12 +165,60 @@ def apply(self,
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256

if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
prefer_torch = envs.VLLM_ROCM_PREFER_TORCH
prefer_triton = envs.VLLM_ROCM_PREFER_TRITON

if (FP16_MATMUL_HEURISTIC_CONDITION
or (prefer_torch and not prefer_triton)):
if prefer_triton:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
else:
out = torch_awq_dequantize(qweight, scales, qzeros)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)


def torch_awq_dequantize(qweights: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor) -> torch.Tensor:
reverse_awq_func_desc = torch.tensor([0, 16, 4, 20, 8, 24, 12, 28],
dtype=torch.int32,
device=qweights.device)
if qzeros is None:
qzeros = torch.zeros_like(qweights)

while qweights.dim() < 2:
qweights = torch.unsqueeze(qweights, 0)
while qzeros.dim() < 2:
qzeros = torch.unsqueeze(qzeros, 0)
while scales.dim() < 2:
scales = torch.unsqueeze(scales, 0)

rows = qweights.size(-2)
group_size_zeros = rows // qzeros.size(-2)
group_size_scales = rows // scales.size(-2)

qweights_shape = list(qweights.shape)
qweights_shape[-1] *= 8
qzeros_shape = list(qzeros.shape)
qzeros_shape[-1] *= 8

qweights = torch.unsqueeze(qweights, -1)
qzeros = torch.unsqueeze(qzeros, -1)

unpacked_weights = torch.bitwise_right_shift(qweights,
reverse_awq_func_desc)
unpacked_weights = torch.bitwise_and(unpacked_weights, 0xf)
unpacked_weights = unpacked_weights.to(torch.int8).view(qweights_shape)

unpacked_zeros = torch.bitwise_right_shift(qzeros, reverse_awq_func_desc)
unpacked_zeros = torch.bitwise_and(unpacked_zeros, 0xf)
unpacked_zeros = unpacked_zeros.to(torch.int8).view(qzeros_shape)
unpacked_zeros = unpacked_zeros.repeat_interleave(group_size_zeros, dim=-2)

functional_scales = scales.repeat_interleave(group_size_scales, dim=-2)
return (unpacked_weights - unpacked_zeros) * functional_scales
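
As a quick sanity check of the nibble order this function assumes (shifts 0, 16, 4, 20, 8, 24, 12, 28), the CPU-only example below is illustrative and not part of the commit: packing eight known int4 values with those shifts, using zero-points of 0 and unit scales, should reproduce the values unchanged.

    # Illustrative round-trip of the int4 packing order assumed above.
    import torch
    from vllm.model_executor.layers.quantization.awq import torch_awq_dequantize

    vals = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
    shifts = torch.tensor([0, 16, 4, 20, 8, 24, 12, 28], dtype=torch.int32)
    packed = (vals << shifts).sum().to(torch.int32).reshape(1, 1)  # one int32, eight int4 values

    qzeros = torch.zeros(1, 1, dtype=torch.int32)    # zero-points of 0
    scales = torch.ones(1, 8, dtype=torch.float16)   # unit scales
    print(torch_awq_dequantize(packed, scales, qzeros))
    # Expected: tensor([[0., 1., 2., 3., 4., 5., 6., 7.]], dtype=torch.float16)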

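Putting the new flags together, the path taken by AWQLinearMethod.apply can be restated as below. This is an editorial reading of the apply() hunk earlier in this file (with the Triton preference winning when both preference flags are set), not code from the commit.

    # Editorial restatement of the dispatch in AWQLinearMethod.apply.
    def select_awq_path(num_tokens: int, prefer_torch: bool, prefer_triton: bool) -> str:
        heuristic = num_tokens >= 256  # FP16_MATMUL_HEURISTIC_CONDITION
        if heuristic or (prefer_torch and not prefer_triton):
            dequant = "ops.awq_dequantize" if prefer_triton else "torch_awq_dequantize"
            return dequant + " + torch.matmul"
        return "ops.awq_gemm"

    print(select_awq_path(512, prefer_torch=False, prefer_triton=True))   # ops.awq_dequantize + torch.matmul
    print(select_awq_path(1, prefer_torch=True, prefer_triton=False))     # torch_awq_dequantize + torch.matmul
    print(select_awq_path(1, prefer_torch=False, prefer_triton=True))     # ops.awq_gemm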