diff --git a/setup.py b/setup.py
index 67a8d2e576..0acfe3aa95 100644
--- a/setup.py
+++ b/setup.py
@@ -69,6 +69,7 @@ def use_debug_mode():
 import torch
 from torch.utils.cpp_extension import (
     CUDA_HOME,
+    ROCM_HOME,
     IS_WINDOWS,
     BuildExtension,
     CppExtension,
@@ -203,22 +204,31 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-    if CUDA_HOME is None and torch.cuda.is_available():
+    if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
         print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
         print(
             "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )
+    if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
+        print("ROCm is not available. Skipping compilation of ROCm extensions")
+        print(
+            "If you'd like to compile ROCm extensions locally please install ROCm"
+        )
 
     use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
-    extension = CUDAExtension if use_cuda else CppExtension
+    use_rocm = torch.cuda.is_available() and ROCM_HOME is not None
+    extension = CUDAExtension if (use_cuda or use_rocm) else CppExtension
+
+    nvcc_args = [
+        "-O3" if not debug_mode else "-O0",
+        "-t=0",
+    ]
+    rocm_args = ["-O3" if not debug_mode else "-O0"]
 
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": nvcc_args if use_cuda else rocm_args,
     }
 
     if not IS_WINDOWS:
@@ -240,17 +250,43 @@ def get_extensions():
             extra_compile_args["nvcc"].append("-g")
             extra_link_args.append("/DEBUG")
 
+    if use_rocm:
+        # naive search for hipblaslt.h; check whether any header found defines HIPBLASLT_ORDER_COL16
+        found = False
+        print("ROCM_HOME", ROCM_HOME)
+        hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")))
+        print("hipblaslt_headers", hipblaslt_headers)
+        for header in hipblaslt_headers:
+            with open(header) as f:
+                if "HIPBLASLT_ORDER_COL16" in f.read():
+                    found = True
+                    break
+        if found:
+            extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
+            print("hipblaslt found extended col order enums")
+        else:
+            print("hipblaslt does not have extended col order enums")
+
     this_dir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(this_dir, "torchao", "csrc")
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+    extensions_rocm_dir = os.path.join(extensions_dir, "rocm")
     cuda_sources = list(
         glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
     )
+    rocm_sources = list(
+        glob.glob(os.path.join(extensions_rocm_dir, "**/*.hip"), recursive=True)
+    )
+    rocm_sources += list(
+        glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True)
+    )
 
     if use_cuda:
         sources += cuda_sources
 
+    if use_rocm:
+        sources += rocm_sources
+
     use_cutlass = False
     if use_cuda and not IS_WINDOWS:
diff --git a/test/test_ops.py b/test/test_ops.py
index 54efefb026..5b380b28eb 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -20,6 +20,9 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode
 
+IS_CUDA = torch.cuda.is_available() and torch.version.cuda
+IS_ROCM = torch.cuda.is_available() and torch.version.hip
+
 if is_fbcode():
     pytest.skip(
         "Skipping the test in fbcode since we
don't have TARGET file for kernels" @@ -49,7 +52,7 @@ def _create_floatx_inputs( fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) def test_quant_llm_linear(self, ebits, mbits, dtype): @@ -79,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): test_utils=test_utils, ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -136,7 +139,7 @@ def make_test_id(param): return f"tiles_{param}" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): @@ -154,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): # TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): @@ -200,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): return dq.reshape(n, k) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -268,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -334,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( assert diff_op_ao < 1e-1 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -445,7 +448,7 @@ def reshape_w(w): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") 
@pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -535,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -614,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) +@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available") +def test_swizzle_mm(): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AT_LEAST_2_5: + test_utils.append("test_aot_dispatch_dynamic") + + mat1 = torch.randint(0, 16, dtype=torch.float, size=(16,32), device="cuda") + mat2 = torch.randint(0, 16, dtype=torch.float, size=(32,16), device="cuda") + + opcheck( + torch.ops.torchao.swizzle_mm, + (mat1, mat2), + test_utils=test_utils, + ) + + if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/__init__.py b/torchao/__init__.py index 11716da62e..9db71b1471 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -55,12 +55,13 @@ quantize_, ) -from . import dtypes, testing +from . import dtypes, swizzle, testing __all__ = [ "dtypes", "autoquant", "quantize_", + "swizzle", "testing", "ops", ] diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp new file mode 100644 index 0000000000..294cd19e05 --- /dev/null +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -0,0 +1,468 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +using at::Scalar; +using at::Tensor; +using at::TensorArg; +using c10::IntArrayRef; + +// +// copied from aten/src/ATen/cuda/CUDABlas.cpp +// +namespace { + +static hipblasOperation_t _cublasOpFromChar(char op) { + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + TORCH_CHECK(false, + "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static void _cublasAdjustLdLevel3( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + int64_t* lda, + int64_t* ldb, + int64_t* ldc) { + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); + + // Note: leading dimensions generally are checked that they are > 0 + // and at least as big the result requires (even if the value won't + // be used). + if (n <= 1) + *ldc = std::max(m, 1); + + if (transa_) { + if (m <= 1) + *lda = std::max(k, 1); + } else { + if (k <= 1) + *lda = std::max(m, 1); + } + + if (transb_) { + if (k <= 1) + *ldb = std::max(n, 1); + } else { + if (n <= 1) + *ldb = std::max(k, 1); + } +} + +// Following the pattern of CuSparseDescriptor +// Defined here for now because this is the only place cublas_lt interface is +// used but can be moved to a header once cublas_lt interface is used in +// multiple places. 
+template <typename T, hipblasStatus_t (*destructor)(T*)>
+struct HipBlasLtDeleter {
+  void operator()(T* x) {
+    if (x != nullptr) {
+      TORCH_CUDABLAS_CHECK(destructor(x));
+    }
+  }
+};
+
+template <typename T, hipblasStatus_t (*destructor)(T*)>
+class HipBlasLtDescriptor {
+ public:
+  T* descriptor() const {
+    return descriptor_.get();
+  }
+  T* descriptor() {
+    return descriptor_.get();
+  }
+
+ protected:
+  std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_;
+};
+
+class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor<
+                                      hipblasLtMatmulDescOpaque_t,
+                                      &hipblasLtMatmulDescDestroy> {
+ public:
+  HipBlasLtMatmulDescriptor(
+      hipblasComputeType_t compute_type,
+      hipDataType scale_type) {
+    hipblasLtMatmulDesc_t raw_descriptor = nullptr;
+    TORCH_CUDABLAS_CHECK(
+        hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type));
+    descriptor_.reset(raw_descriptor);
+  }
+  template <typename T>
+  inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) {
+    // NOLINTNEXTLINE(bugprone-sizeof-expression)
+    TORCH_CUDABLAS_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value)));
+  }
+};
+
+class HipBlasLtMatrixLayout : public HipBlasLtDescriptor<
+                                  hipblasLtMatrixLayoutOpaque_t,
+                                  &hipblasLtMatrixLayoutDestroy> {
+ public:
+  HipBlasLtMatrixLayout(
+      hipDataType type,
+      uint64_t rows,
+      uint64_t cols,
+      int64_t ld,
+      bool t = false) {
+    hipblasLtMatrixLayout_t raw_descriptor = nullptr;
+    TORCH_CUDABLAS_CHECK(
+        hipblasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld));
+    descriptor_.reset(raw_descriptor);
+  }
+  template <typename T>
+  inline void setAttribute(hipblasLtMatrixLayoutAttribute_t attr, const T value) {
+    TORCH_CUDABLAS_CHECK(::hipblasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T)));
+  }
+};
+
+class HipBlasLtMatmulPreference : public HipBlasLtDescriptor<
+                                      hipblasLtMatmulPreferenceOpaque_t,
+                                      &hipblasLtMatmulPreferenceDestroy> {
+ public:
+  HipBlasLtMatmulPreference() {
+    hipblasLtMatmulPreference_t raw_descriptor = nullptr;
+    TORCH_CUDABLAS_CHECK(hipblasLtMatmulPreferenceCreate(&raw_descriptor));
+    descriptor_.reset(raw_descriptor);
+  }
+  template <typename T>
+  inline void setAttribute(hipblasLtMatmulPreferenceAttributes_t attr, const T value) {
+    TORCH_CUDABLAS_CHECK(::hipblasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T)));
+  }
+};
+
+static size_t _parseChosenWorkspaceSize() {
+  auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE");
+#ifdef USE_ROCM
+  if (!val.has_value()) {
+    // accept either env var
+    val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE");
+  }
+  size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */
+#else
+  size_t workspace_size = 1024; /* default size in KiB according to #73328 */
+#endif
+
+  if (val.has_value()) {
+    try {
+      workspace_size = std::stoi(val.value());
+    } catch(std::invalid_argument const& e) {
+      TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    } catch(std::out_of_range const& e) {
+      TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    }
+  }
+  return workspace_size * 1024;
+}
+
+static size_t _getWorkspaceSize() {
+  static size_t workspace_size = _parseChosenWorkspaceSize();
+  return workspace_size;
+}
+
+} // namespace
+
+//
+// copied from aten/src/ATen/native/cuda/Blas.cpp
+//
+namespace {
+
+// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492
+c10::MaybeOwned<Tensor> inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) {
+  if (resolve_conj && tensor.is_conj()) {
+    return c10::MaybeOwned<Tensor>::owned(tensor.resolve_conj());
+  } else {
+    return c10::MaybeOwned<Tensor>::borrowed(tensor);
+  }
+}
+
+c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) {
+  if (tensor.is_non_overlapping_and_dense()) { // common case
+      transpose_tensor = tensor.is_contiguous();
+      return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor);
+  }
+  IntArrayRef tensor_strides = tensor.strides();
+  IntArrayRef tensor_sizes = tensor.sizes();
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
+    transpose_tensor = false;
+    return resolve_conj_if_indicated(tensor, !transpose_result);
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
+    transpose_tensor = true;
+    return resolve_conj_if_indicated(tensor, transpose_result);
+  } else {
+    transpose_tensor = true;
+    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
+  }
+}
+
+c10::MaybeOwned<Tensor> inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) {
+  if (tensor.is_non_overlapping_and_dense()) { // common case
+      transpose_tensor = tensor.is_contiguous();
+      return resolve_conj_if_indicated(tensor, true);
+  }
+
+  IntArrayRef tensor_strides = tensor.strides();
+  IntArrayRef tensor_sizes = tensor.sizes();
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
+    transpose_tensor = false;
+    return resolve_conj_if_indicated(tensor, true);
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
+    transpose_tensor = true;
+    return resolve_conj_if_indicated(tensor, true);
+  } else {
+    transpose_tensor = true;
+    return c10::MaybeOwned<Tensor>::owned(tensor.clone(at::MemoryFormat::Contiguous));
+  }
+}
+
+struct cublasCommonArgs {
+  cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) {
+    bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false;
+    result = prepare_matrix_for_cublas(c, transpose_result);
+    mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result);
+    matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result);
+    auto mat1_sizes = mat1.sizes();
+    auto mat2_sizes = mat2.sizes();
+    if (transpose_result) {
+      transpose_mat1 = !transpose_mat1;
+      transpose_mat2 = !transpose_mat2;
+      mat1_sizes = mata->sizes();
+      mat2_sizes = matb->sizes();
+    }
+
+    m = mat1_sizes[transpose_result ? 1 : 0];
+    k = mat1_sizes[transpose_result ? 0 : 1];
+    n = mat2_sizes[transpose_result ? 0 : 1];
+    lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0);
+    ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0);
+    result_ld = result->stride(transpose_result ? 0 : 1);
+    transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n';
+    transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n';
+  }
+  char transa, transb;
+  int64_t m, n, k;
+  int64_t lda, ldb, result_ld;
+  c10::MaybeOwned<Tensor> mata, matb, result;
+};
+
+} // namespace
+
+template <typename Dtype>
+inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) {
+  hipDataType abcType = HIP_R_32F;
+  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
+  hipDataType scaleType = HIP_R_32F;
+  if constexpr (std::is_same_v<Dtype, double>) {
+    abcType = HIP_R_64F;
+    computeType = HIPBLAS_COMPUTE_64F;
+    scaleType = HIP_R_64F;
+  } else if constexpr (std::is_same_v<Dtype, float>) {
+  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
+    abcType = HIP_C_64F;
+    computeType = HIPBLAS_COMPUTE_64F;
+    scaleType = HIP_C_64F;
+  } else if constexpr (std::is_same_v<Dtype, c10::complex<float>>) {
+    abcType = HIP_C_32F;
+    scaleType = HIP_C_32F;
+  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
+    abcType = HIP_R_16F;
+  } else if constexpr (std::is_same_v<Dtype, at::BFloat16>) {
+    abcType = HIP_R_16BF;
+  } else {
+    static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented");
+  }
+
+  hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  hipblasOperation_t opa = _cublasOpFromChar(transa);
+  hipblasOperation_t opb = _cublasOpFromChar(transb);
+  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
+
+  HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
+  computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa);
+  computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb);
+  HipBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == HIPBLAS_OP_T);
+  HipBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == HIPBLAS_OP_T);
+  HipBlasLtMatrixLayout Cdesc(abcType, m, n, ldc);
+#ifdef HIPBLASLT_HAS_ORDER_COL16
+  if (mat1_is_swizzled) {
+    Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8);
+  }
+  if (mat2_is_swizzled) {
+    Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8);
+  }
+#endif
+
+  if (num_batches > 1) {
+    int num_batches_as_int = static_cast<int>(num_batches);
+    Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
+    Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
+    Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int);
+    Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea);
+    Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb);
+    Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec);
+  }
+
+  HipBlasLtMatmulPreference preference;
+  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
+  // setting this to 1M.
+  size_t workspaceSize = _getWorkspaceSize();
+  preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
+
+#ifndef USE_ROCM
+  uint32_t a_alignment = _getAlignment(reinterpret_cast<uintptr_t>(a));
+  uint32_t b_alignment = _getAlignment(reinterpret_cast<uintptr_t>(b));
+  uint32_t c_alignment = _getAlignment(reinterpret_cast<uintptr_t>(c));
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment);
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment);
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
+#endif
+
+  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
+
+  hipblasLtMatmulHeuristicResult_t heuristicResult = {};
+  int returnedResult = 0;
+  TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic(
+      ltHandle,
+      computeDesc.descriptor(),
+      Adesc.descriptor(),
+      Bdesc.descriptor(),
+      Cdesc.descriptor(),
+      Cdesc.descriptor(),
+      preference.descriptor(),
+      1,
+      &heuristicResult,
+      &returnedResult));
+  if (returnedResult == 0) {
+    TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED);
+  }
+
+  hipblasStatus_t cublasStatus = hipblasLtMatmul(
+      ltHandle,
+      computeDesc.descriptor(),
+      &alpha,
+      a,
+      Adesc.descriptor(),
+      b,
+      Bdesc.descriptor(),
+      &beta,
+      c,
+      Cdesc.descriptor(),
+      c,
+      Cdesc.descriptor(),
+      &heuristicResult.algo,
+      workspace.mutable_data_ptr(),
+      workspaceSize,
+      at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
+  TORCH_CHECK(
+      cublasStatus == HIPBLAS_STATUS_SUCCESS,
+      "CUDA error: ",
+      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
+      " when calling hipblasLtMatmul with transpose_mat1 ",
+      (opa == HIPBLAS_OP_T),
+      " transpose_mat2 ",
+      (opb == HIPBLAS_OP_T),
+      " m ",
+      m,
+      " n ",
+      n,
+      " k ",
+      k,
+      " lda ",
+      lda,
+      " ldb ",
+      ldb,
+      " ldc ",
+      ldc,
+      " abcType ",
+      abcType,
+      " computeType ",
+      computeType,
+      " scaleType ",
+      scaleType);
+}
+
+
+template <typename Dtype>
+inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) {
+  // forward to bgemm implementation but set strides and batches to 0
+  bgemm_hipblaslt<Dtype>(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0, mat1_is_swizzled, mat2_is_swizzled);
+}
+
+
+Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) {
+  TORCH_CHECK(
+    mat1.dtype() == mat2.dtype(),
+    "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype()
+  );
+
+  // NOLINTNEXTLINE(*c-array*)
+  TensorArg targs[]{{mat1, "mat1", 0}, {mat2, "mat2", 1}};
+  checkAllSameGPU(__func__, targs);
+
+  Tensor meta_mat1 = mat1.to("meta");
+  Tensor meta_mat2 = mat2.to("meta");
+  Tensor meta_result = at::mm(meta_mat1, meta_mat2);
+  Tensor result = at::empty_like(meta_result, mat1.device());
+  at::ScalarType scalar_type = result.scalar_type();
+
+  cublasCommonArgs args(mat1, mat2, result);
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half,
+      at::ScalarType::BFloat16,
+      scalar_type,
+      "addmm_cuda",
+      [&] {
+        using opmath_t = at::opmath_type<scalar_t>;
+        opmath_t alpha_val = opmath_t(1.0);
+        opmath_t beta_val = opmath_t(0.0);
+        const scalar_t* mat1_ptr = args.mata->const_data_ptr<scalar_t>();
+        const scalar_t* mat2_ptr = args.matb->const_data_ptr<scalar_t>();
+        scalar_t* result_ptr = args.result->mutable_data_ptr<scalar_t>();
+        gemm_hipblaslt<scalar_t>(
+            args.transa,
+            args.transb,
+            args.m,
+            args.n,
+            args.k,
+            alpha_val,
+            mat1_ptr,
+            args.lda,
+            mat2_ptr,
+            args.ldb,
+            beta_val,
+            result_ptr,
+            args.result_ld,
+            mat1_is_swizzled,
+
mat2_is_swizzled); + }); + + return result; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::swizzle_mm", &swizzle_mm); +} diff --git a/torchao/ops.py b/torchao/ops.py index 8b573876f2..179732776a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -25,6 +25,9 @@ lib.define( "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) +lib.define( + "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" +) def register_custom_op(name): @@ -592,3 +595,18 @@ def _( bias: Tensor, ) -> Tensor: return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def swizzle_mm(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool) -> Tensor: + """ + Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. + + """ + return torch.ops.torchao.swizzle_mm.default( + mat1, mat2, mat1_is_swizzled, mat2_is_swizzled + ) + + +@register_custom_op("torchao::swizzle_mm") +def _(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool) -> Tensor: + return mat1.new_empty(mat1.shape[0], mat2.shape[1]) diff --git a/torchao/swizzle/__init__.py b/torchao/swizzle/__init__.py new file mode 100644 index 0000000000..b5135532ef --- /dev/null +++ b/torchao/swizzle/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .swizzle_tensor import SwizzleTensor + +__all__ = [ + "SwizzleTensor" +] diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py new file mode 100644 index 0000000000..1128115633 --- /dev/null +++ b/torchao/swizzle/swizzle_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
+from typing import Any, Dict
+
+import torch
+
+import torchao.ops
+from torchao.swizzle.swizzle_tensor import SwizzleTensor
+
+aten = torch.ops.aten
+SWIZZLE_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Register aten ops to the swizzle op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            SWIZZLE_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements([aten.mm.default])
+def swizzle_mm(aten_op, args, kwargs=None):
+    a = args[0]
+    b = args[1]
+
+    if torch.is_floating_point(a) and torch.is_floating_point(b):
+        a_is_swizzled = False
+        b_is_swizzled = False
+        if isinstance(a, SwizzleTensor):
+            a = a.as_tensor()
+            a_is_swizzled = True
+        if isinstance(b, SwizzleTensor):
+            b = b.as_tensor()
+            b_is_swizzled = True
+        tensor_out = torchao.ops.swizzle_mm(a, b, a_is_swizzled, b_is_swizzled)
+    else:
+        a = a.unswizzle() if isinstance(a, SwizzleTensor) else a
+        b = b.unswizzle() if isinstance(b, SwizzleTensor) else b
+        tensor_out = aten_op(a, b, **kwargs)
+    return tensor_out
+
+
+@implements([aten.bmm.default])
+def swizzle_bmm(aten_op, args, kwargs=None):
+    a = args[0]
+    b = args[1]
+
+    a = a.unswizzle() if isinstance(a, SwizzleTensor) else a
+    b = b.unswizzle() if isinstance(b, SwizzleTensor) else b
+    return aten_op(a, b, **kwargs)
+
+
+@implements([aten.addmm.default])
+def swizzle_addmm(aten_op, args, kwargs=None):
+    bias = args[0]
+    a = args[1]
+    b = args[2]
+    a = a.unswizzle() if isinstance(a, SwizzleTensor) else a
+    b = b.unswizzle() if isinstance(b, SwizzleTensor) else b
+    return aten_op(bias, a, b, *args[3:], **kwargs)
+
+
diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py
new file mode 100644
index 0000000000..0b1cb18643
--- /dev/null
+++ b/torchao/swizzle/swizzle_tensor.py
@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.utils._pytree import tree_map
+
+# copied from float8_utils.py
+def _get_min_alignment(size: int, alignment_value: int) -> int:
+    return (1 + ((size - 1) // alignment_value)) * alignment_value
+
+class SwizzleTensor(torch.Tensor):
+    """
+    A Python-only swizzled tensor subclass.
+
+    Intended usage of this abstraction:
+    Swizzle weight Tensor to avoid LDS use during GEMMs on ROCm hardware.
+ """ + + def __new__( + cls, + original: torch.Tensor, + ): + wrapper = torch.empty_like(original, device="meta") + return torch.Tensor._make_subclass(cls, wrapper) + + def __init__(self, original): + assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) + if original.ndim == 2: + M, K = original.shape + B = 0 + if original.ndim == 3: + B, M, K = original.shape + alignedM = _get_min_alignment(M, 16) + alignedK = _get_min_alignment(K, 32) + paddedM = alignedM - M + paddedK = alignedK - K + x = torch.nn.functional.pad(original, (0, paddedK, 0, paddedM), "constant", 0) + if original.ndim == 2: + x = x.view(alignedM//16, 16, alignedK//32, 4, 8) + x = x.permute(0, 2, 3, 1, 4) + if original.ndim == 3: + x = x.view(B, alignedM//16, 16, alignedK//32, 4, 8) + x = x.permute(0, 1, 3, 4, 2, 5) + self.x = x.contiguous() + self.B = B + self.M = M + self.K = K + self.alignedM = alignedM + self.alignedK = alignedK + self.paddedM = paddedM + self.paddedK = paddedK + self.original_ndim = original.ndim + + def __repr__(self): + return f"{self.__class__.__name__}(original={self.unswizzle()})" + + def unswizzle(self): + if self.original_ndim == 2: + undone = self.x.permute(0, 3, 1, 2, 4).contiguous() + undone = undone.reshape(self.alignedM, self.alignedK) + undone = undone[0:self.M, 0:self.K] + return undone.reshape(self.M, self.K) + if self.original_ndim == 3: + undone = self.x.permute(0, 1, 4, 2, 3, 5).contiguous() + undone = undone.reshape(self.B, self.alignedM, self.alignedK) + undone = undone[0:self.B, 0:self.M, 0:self.K] + return undone.reshape(self.B, self.M, self.K) + + def as_tensor(self): + if self.original_ndim == 2: + return self.x.reshape(self.alignedM, self.alignedK) + if self.original_ndim == 3: + return self.x.reshape(self.B, self.alignedM, self.alignedK) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Lazy import to avoid circular dependency + from torchao.swizzle.swizzle_ops import SWIZZLE_OPS_TABLE + if func in SWIZZLE_OPS_TABLE: + return SWIZZLE_OPS_TABLE[func](func, args, kwargs) + + def unwrap(e): + return e.unswizzle() if isinstance(e, SwizzleTensor) else e + + def wrap(e): + return SwizzleTensor(e) if isinstance(e, torch.Tensor) else e + + return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + + # Do not force the SwizzleTensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl
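
Usage note (not part of the diff): a minimal sketch of how the pieces above fit together, assuming a ROCm build of torchao where the swizzle extension compiled and hipBLASLt exposes the extended ordering enums (the HIPBLASLT_HAS_ORDER_COL16 check in setup.py). Shapes and tolerances are illustrative only.

import torch
from torchao.swizzle import SwizzleTensor

# fp16 operands on the GPU; SwizzleTensor pads M to a multiple of 16 and K to a multiple of 32
a = torch.randn(64, 128, dtype=torch.half, device="cuda")
w = torch.randn(128, 64, dtype=torch.half, device="cuda")

# Wrapping the weight swizzles it eagerly; torch.mm then routes through
# SwizzleTensor.__torch_dispatch__ -> swizzle_ops.swizzle_mm -> torch.ops.torchao.swizzle_mm,
# which tells hipBLASLt that this operand is in the swizzled (COL16_4R8) layout.
w_swizzled = SwizzleTensor(w)
out = torch.mm(a, w_swizzled)

# Reference result from the plain, unswizzled weight
ref = torch.mm(a, w)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)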