diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu
index 785f1a09c1900..2defb5f7e919a 100644
--- a/csrc/quantization/gptq/q_gemm.cu
+++ b/csrc/quantization/gptq/q_gemm.cu
@@ -30,11 +30,11 @@ namespace gptq {
 #define MAX_ALT_GEMM_ROWS 8
 #define THREADS_X 32
 #define THREADS_Y 32
-#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
+#define DIVIDE(x, size) (((x) + (size)-1) / (size))
 
 #if defined(USE_ROCM)
   #include <hipblas/hipblas.h>
-__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
+__host__ __forceinline__ hipblasStatus_t gemm1(
     hipblasHandle_t handle, hipblasOperation_t transA,
     hipblasOperation_t transB, int m, int n, int k, const half* alpha,
     const half* AP, int lda, const half* BP, int ldb, const half* beta,
@@ -46,7 +46,99 @@ __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
                       reinterpret_cast<const hipblasHalf*>(beta),
                       reinterpret_cast<hipblasHalf*>(CP), ldc);
 }
-  #define hipblasHgemm __compat_hipblasHgemm
+
+  #include <hipblaslt/hipblaslt.h>
+  #include <cstdio>
+
+  #define HIPBLASLT_CHECK(EXPR)                                           \
+    do {                                                                  \
+      hipblasStatus_t status_ = (EXPR);                                   \
+      if (status_ != HIPBLAS_STATUS_SUCCESS)                              \
+        printf("hipblaslt error: %s\n", hipblasStatusToString(status_));  \
+    } while (0)
+
+__host__ __forceinline__ hipblasStatus_t gemm2(
+    hipblasLtHandle_t handle, hipblasOperation_t transA,
+    hipblasOperation_t transB, int m, int n, int k, const half* alpha,
+    const half* AP, int lda, const half* BP, int ldb, const half* beta,
+    half* CP, int ldc) {
+  // Create operation descriptor
+  hipblasLtMatmulDesc_t matmulDesc;
+  hipblasStatus_t status =
+      hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32F, HIP_R_32F);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
+      matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &transA, sizeof(transA)));
+  HIPBLASLT_CHECK(hipblasLtMatmulDescSetAttribute(
+      matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &transB, sizeof(transB)));
+
+  // Set matrix layout descriptors
+  hipblasLtMatrixLayout_t Adesc, Bdesc, Cdesc, Ddesc;
+  status = hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_16F, m, k, lda);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  status = hipblasLtMatrixLayoutCreate(&Bdesc, HIP_R_16F, k, n, ldb);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  status = hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_16F, m, n, ldc);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  // Ddesc is the same as Cdesc in this use case
+  status = hipblasLtMatrixLayoutCreate(&Ddesc, HIP_R_16F, m, n, ldc);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  // Create matmul preference and get heuristic result
+  hipblasLtMatmulPreference_t preference;
+  status = hipblasLtMatmulPreferenceCreate(&preference);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    return status;
+  }
+
+  hipblasLtMatmulHeuristicResult_t heuristicResult;
+  int returnedAlgoCount = 0;
+  status = hipblasLtMatmulAlgoGetHeuristic(
+      handle, matmulDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1,
+      &heuristicResult, &returnedAlgoCount);
+  if (status != HIPBLAS_STATUS_SUCCESS || returnedAlgoCount == 0) {
+    return status;
+  }
+
+  // Perform the matrix multiplication
+  status =
+      hipblasLtMatmul(handle, matmulDesc, alpha, AP, Adesc, BP, Bdesc, beta,
+                      CP, Cdesc, CP, Ddesc, &heuristicResult.algo, nullptr, 0,
+                      0);
+  if (status != HIPBLAS_STATUS_SUCCESS) {
+    printf("hipblaslt matmul failed\n");
+    return status;
+  }
+
+  // Clean up resources
+  hipblasLtMatmulPreferenceDestroy(preference);
+  hipblasLtMatrixLayoutDestroy(Adesc);
+  hipblasLtMatrixLayoutDestroy(Bdesc);
+  hipblasLtMatrixLayoutDestroy(Cdesc);
+  hipblasLtMatrixLayoutDestroy(Ddesc);
+  hipblasLtMatmulDescDestroy(matmulDesc);
+
+  return status;
+}
+
+  // Route hipblasHgemm calls to the hipblasLt-based gemm2
+  #define hipblasHgemm gemm2
 
 // Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
   #define rocblas_operation_none HIPBLAS_OP_N
@@ -1493,15 +1585,7 @@ void gemm_half_q_half_cuda(cublasHandle_t cublas_handle, const half* a,
                            const half* b_gptq_scales, const int* b_g_idx,
                            half* c, half* temp_dq, int size_m, int size_n,
                            int size_k, int groups, bool use_exllama, int bit) {
-  bool use_reconstruct;
-  if (use_exllama) {
-    use_reconstruct = ((bit == 8 && size_m > MAX_Q_GEMM_ROWS_8BIT) ||
-                       (bit != 8 && size_m > MAX_Q_GEMM_ROWS));
-  } else {
-    // The 2/3-bit kernels are somehow slower than dequant + gemm baseline, so
-    // we disabled them for now.
-    use_reconstruct = (bit < 4 || size_m > MAX_ALT_GEMM_ROWS);
-  }
+  bool use_reconstruct = true;
  if (use_reconstruct) {
    // Reconstruct FP16 matrix, then cuBLAS
    if (use_exllama) {
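
The standalone sketch below is not part of the patch; it exercises the same hipblasLt call sequence that gemm2 relies on, as a self-contained FP16 GEMM on ROCm. The buffer names (dA, dB, dC), the 64x64x64 problem size, and the column-major layouts are illustrative assumptions; the API calls themselves (hipblasLtCreate, hipblasLtMatmulDescCreate, hipblasLtMatmulAlgoGetHeuristic, hipblasLtMatmul, and the matching destroy calls) are the hipBLASLt interface the patch uses.

// Illustrative only: mirrors gemm2's hipblasLt sequence outside q_gemm.cu.
#include <cstdio>
#include <vector>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hipblaslt/hipblaslt.h>

int main() {
  const int m = 64, n = 64, k = 64;  // assumed problem size
  std::vector<__half> hA(m * k, __float2half(1.0f));
  std::vector<__half> hB(k * n, __float2half(1.0f));
  std::vector<__half> hC(m * n, __float2half(0.0f));

  __half *dA = nullptr, *dB = nullptr, *dC = nullptr;
  hipMalloc(reinterpret_cast<void**>(&dA), hA.size() * sizeof(__half));
  hipMalloc(reinterpret_cast<void**>(&dB), hB.size() * sizeof(__half));
  hipMalloc(reinterpret_cast<void**>(&dC), hC.size() * sizeof(__half));
  hipMemcpy(dA, hA.data(), hA.size() * sizeof(__half), hipMemcpyHostToDevice);
  hipMemcpy(dB, hB.data(), hB.size() * sizeof(__half), hipMemcpyHostToDevice);
  hipMemcpy(dC, hC.data(), hC.size() * sizeof(__half), hipMemcpyHostToDevice);

  hipblasLtHandle_t handle;
  hipblasLtCreate(&handle);

  // Operation descriptor: FP32 accumulation/scaling, no transposes.
  hipblasLtMatmulDesc_t matmulDesc;
  hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32F, HIP_R_32F);
  hipblasOperation_t opN = HIPBLAS_OP_N;
  hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                  &opN, sizeof(opN));
  hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
                                  &opN, sizeof(opN));

  // Column-major FP16 layouts; leading dimension = number of rows.
  hipblasLtMatrixLayout_t Adesc, Bdesc, Cdesc;
  hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_16F, m, k, m);
  hipblasLtMatrixLayoutCreate(&Bdesc, HIP_R_16F, k, n, k);
  hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_16F, m, n, m);

  // Ask the heuristic for a single algorithm, then run the matmul with it.
  hipblasLtMatmulPreference_t preference;
  hipblasLtMatmulPreferenceCreate(&preference);
  hipblasLtMatmulHeuristicResult_t result;
  int returnedAlgoCount = 0;
  hipblasLtMatmulAlgoGetHeuristic(handle, matmulDesc, Adesc, Bdesc, Cdesc,
                                  Cdesc, preference, 1, &result,
                                  &returnedAlgoCount);

  float alpha = 1.0f, beta = 0.0f;  // FP32 scalars match the HIP_R_32F scale type
  hipblasStatus_t status = hipblasLtMatmul(
      handle, matmulDesc, &alpha, dA, Adesc, dB, Bdesc, &beta, dC, Cdesc, dC,
      Cdesc, &result.algo, nullptr, 0, 0);
  hipDeviceSynchronize();
  printf("algos found: %d, matmul status: %d\n", returnedAlgoCount,
         static_cast<int>(status));

  // Tear everything down in reverse order of creation.
  hipblasLtMatmulPreferenceDestroy(preference);
  hipblasLtMatrixLayoutDestroy(Adesc);
  hipblasLtMatrixLayoutDestroy(Bdesc);
  hipblasLtMatrixLayoutDestroy(Cdesc);
  hipblasLtMatmulDescDestroy(matmulDesc);
  hipblasLtDestroy(handle);
  hipFree(dA);
  hipFree(dB);
  hipFree(dC);
  return 0;
}

As in gemm2, no max-workspace attribute is set on the preference, so the heuristic only returns algorithms that need zero workspace bytes, which is why a null workspace pointer and size 0 are passed to hipblasLtMatmul.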