Skip to content

Commit

Permalink
only use dmmv for supported types
Browse files Browse the repository at this point in the history
  • Loading branch information
slaren committed Aug 1, 2024
1 parent f2596b3 commit 8e68836
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1885,7 +1885,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);

bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
Expand Down
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/dmmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -672,3 +672,12 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}

bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) {
return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 ||
src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 ||
src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K ||
src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K ||
src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K ||
src0_type == GGML_TYPE_F16;
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-cuda/dmmv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_dmmv_type_supported(ggml_type src0_type);

0 comments on commit 8e68836

Please sign in to comment.