[ROCm] preshuffled weight mm #1702

Draft: wants to merge 3 commits into base: main
48 changes: 42 additions & 6 deletions setup.py
@@ -69,6 +69,7 @@ def use_debug_mode():
import torch
from torch.utils.cpp_extension import (
CUDA_HOME,
ROCM_HOME,
IS_WINDOWS,
BuildExtension,
CppExtension,
@@ -203,22 +204,31 @@ def get_extensions():
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if CUDA_HOME is None and torch.cuda.is_available():
if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)
if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
print("ROCm is not available. Skipping compilation of ROCm extensions")
print(
"If you'd like to compile ROCm extensions locally please install ROCm"
)

use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension
use_rocm = torch.cuda.is_available() and ROCM_HOME is not None
extension = CUDAExtension if (use_cuda or use_rocm) else CppExtension

nvcc_args = [
"-O3" if not debug_mode else "-O0",
"-t=0",
]
rocm_args = ["-O3" if not debug_mode else "-O0"]

extra_link_args = []
extra_compile_args = {
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
],
"nvcc": nvcc_args if use_cuda else rocm_args
}

if not IS_WINDOWS:
@@ -240,17 +250,43 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

if use_rocm:
# Naive search for hipblaslt.h; check whether any found header contains HIPBLASLT_ORDER_COL16
found = False
print("ROCM_HOME", ROCM_HOME)
hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")))
print("hipblaslt_headers", hipblaslt_headers)
for header in hipblaslt_headers:
with open(header) as f:
if "HIPBLASLT_ORDER_COL16" in f.read():
found = True
break
if found:
extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
print("hipblaslt found extended col order enums")
else:
print("hipblaslt does not have extended col order enums")

this_dir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(this_dir, "torchao", "csrc")
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
extensions_rocm_dir = os.path.join(extensions_dir, "rocm")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)
rocm_sources = list(
glob.glob(os.path.join(extensions_rocm_dir, "**/*.hip"), recursive=True)
)
rocm_sources += list(
glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True)
)

if use_cuda:
sources += cuda_sources
if use_rocm:
sources += rocm_sources

use_cutlass = False
if use_cuda and not IS_WINDOWS:
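Note on the detection logic above: ROCm builds of PyTorch still report torch.cuda.is_available() as True, so the new checks tell the backends apart via torch.version.cuda vs torch.version.hip, and hipblaslt.h is scanned for HIPBLASLT_ORDER_COL16 before the -DHIPBLASLT_HAS_ORDER_COL16 define is added. A minimal standalone sketch of that probe (helper names here are illustrative, not part of this PR):

import glob
import os

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


def detect_gpu_backend():
    # ROCm builds of PyTorch expose HIP through the CUDA API, so
    # torch.cuda.is_available() is True on both backends; torch.version.hip
    # vs torch.version.cuda is what distinguishes them.
    if torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None:
        return "rocm"
    if torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None:
        return "cuda"
    return "cpu"


def hipblaslt_has_col16(rocm_home=ROCM_HOME):
    # Same idea as the header scan in setup.py: the extended column-order enum
    # gates the -DHIPBLASLT_HAS_ORDER_COL16 compile flag.
    if rocm_home is None:
        return False
    pattern = os.path.join(rocm_home, "include", "hipblaslt", "hipblaslt.h")
    for header in glob.glob(pattern):
        with open(header) as f:
            if "HIPBLASLT_ORDER_COL16" in f.read():
                return True
    return False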
43 changes: 34 additions & 9 deletions test/test_ops.py
@@ -20,6 +20,9 @@
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode

IS_CUDA = torch.cuda.is_available() and torch.version.cuda
IS_ROCM = torch.cuda.is_available() and torch.version.hip

if is_fbcode():
pytest.skip(
"Skipping the test in fbcode since we don't have TARGET file for kernels"
@@ -49,7 +52,7 @@ def _create_floatx_inputs(
fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -79,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
test_utils=test_utils,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
@@ -136,7 +139,7 @@ def make_test_id(param):
return f"tiles_{param}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -154,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):


# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -200,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
return dq.reshape(n, k)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -268,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(


# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -334,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
assert diff_op_ao < 1e-1


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -445,7 +448,7 @@ def reshape_w(w):
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
MARLIN_TEST_PARAMS,
@@ -535,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
MARLIN_TEST_PARAMS,
@@ -614,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
)


@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
def test_swizzle_mm():
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]

# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AT_LEAST_2_5:
test_utils.append("test_aot_dispatch_dynamic")

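# Note: the device string is "cuda" even on ROCm; torch.version.hip is what identifies the HIP backend.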
mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")

opcheck(
torch.ops.torchao.swizzle_mm,
(mat1, mat2),
test_utils=test_utils,
)


if __name__ == "__main__":
pytest.main(sys.argv)
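For reference, the new op is exercised above only through opcheck. A hedged usage sketch based on that two-tensor call (shapes and dtype are illustrative; the op is assumed to be registered once torchao is imported from a ROCm build with the extension compiled):

import torch
import torchao  # assumed to register torch.ops.torchao.* from the compiled extension

mat1 = torch.rand(16, 32, dtype=torch.float, device="cuda")
mat2 = torch.rand(32, 16, dtype=torch.float, device="cuda")

# Mirrors the (mat1, mat2) invocation used by opcheck in test_swizzle_mm.
out = torch.ops.torchao.swizzle_mm(mat1, mat2)
print(out.shape)  # presumably (16, 16) if the op behaves like a standard matmul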
3 changes: 2 additions & 1 deletion torchao/__init__.py
@@ -55,12 +55,13 @@
quantize_,
)

from . import dtypes, testing
from . import dtypes, swizzle, testing

__all__ = [
"dtypes",
"autoquant",
"quantize_",
"swizzle",
"testing",
"ops",
]
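With swizzle added to the top-level import and __all__, the submodule becomes part of the public API. A quick check (assuming a build that includes this change):

import torchao

assert "swizzle" in torchao.__all__
print(torchao.swizzle)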