Skip to content

Commit

Permalink
CUDA: stream-k decomposition for MMQ (ggerganov#8018)
Browse files Browse the repository at this point in the history
* CUDA: stream-k decomposition for MMQ

* fix undefined memory reads for small matrices
  • Loading branch information
JohannesGaessler authored Jun 20, 2024
1 parent 2075a66 commit d50f889
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 112 deletions.
2 changes: 1 addition & 1 deletion ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
}

const int cc = ggml_cuda_info().devices[id].cc;
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
}
return row_rounding;
}
Expand Down
4 changes: 2 additions & 2 deletions ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
}

// Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc, const int mmq_x) {
return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
static int get_mmq_y_host(const int cc) {
return cc >= CC_VOLTA ? 128 : 64;
}

//////////////////////
Expand Down
20 changes: 10 additions & 10 deletions ggml-cuda/mmq.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(

switch (src0->type) {
case GGML_TYPE_Q4_0:
mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
break;
default:
GGML_ASSERT(false);
Expand Down
Loading

0 comments on commit d50f889

Please sign in to comment.