From c5d9effb49649db80a52caf5c0626de6f342f526 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 24 Jan 2025 21:02:43 +0100 Subject: [PATCH] CUDA: fix FP16 cuBLAS GEMM (#11396) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fb3d9e2d92e09..fbe889a01221b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, + &alpha, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, &beta, dst_dd_i, CUDA_R_32F, ldc, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); @@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK( cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N, row_diff, src1_ncols, ne10, - &alpha_f16, src0_ptr, CUDA_R_16F, ne00, - src1_ptr, CUDA_R_16F, ne10, - &beta_f16, dst_dd_i, CUDA_R_16F, ldc, + &alpha_f16, src0_ptr, CUDA_R_16F, ne00, + src1_ptr, CUDA_R_16F, ne10, + &beta_f16, dst_f16.get(), CUDA_R_16F, ldc, CUBLAS_COMPUTE_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));