ggml : add asserts for type conversion in fattn kernels (ggerganov#9971)
ggml-ci
ggerganov authored Oct 21, 2024
1 parent d5ebd79 commit f594bc8
Showing 3 changed files with 8 additions and 4 deletions.
common/common.cpp: 4 changes (2 additions & 2 deletions)
@@ -1035,7 +1035,7 @@ static ggml_type kv_cache_type_from_str(const std::string & s) {
         return GGML_TYPE_Q5_1;
     }

-    throw std::runtime_error("Invalid cache type: " + s);
+    throw std::runtime_error("Unsupported cache type: " + s);
 }

 struct llama_context_params common_context_params_to_llama(const common_params & params) {
@@ -1047,7 +1047,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ubatch          = params.n_ubatch;
     cparams.n_threads         = params.cpuparams.n_threads;
     cparams.n_threads_batch   = params.cpuparams_batch.n_threads == -1 ?
-                                params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
+                                params.cpuparams.n_threads : params.cpuparams_batch.n_threads;
     cparams.logits_all        = params.logits_all;
     cparams.embeddings        = params.embedding;
     cparams.rope_scaling_type = params.rope_scaling_type;
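The first hunk above only rewords the message thrown when a KV-cache type string is not recognized (the second is a whitespace-only alignment fix). A minimal, self-contained sketch of the pattern and of how the exception surfaces to a caller; the enum and cases here are illustrative stand-ins, since the real helper in common/common.cpp is static, returns ggml_type values, and covers more types:

#include <cstdio>
#include <stdexcept>
#include <string>

// Illustrative stand-in for ggml's type enum; not the real ggml_type.
enum class cache_type { f16, q8_0, q5_1 };

static cache_type cache_type_from_str(const std::string & s) {
    if (s == "f16")  { return cache_type::f16;  }
    if (s == "q8_0") { return cache_type::q8_0; }
    if (s == "q5_1") { return cache_type::q5_1; }
    throw std::runtime_error("Unsupported cache type: " + s);
}

int main() {
    try {
        cache_type_from_str("q4_7"); // not a recognized type string
    } catch (const std::runtime_error & e) {
        std::fprintf(stderr, "%s\n", e.what()); // prints: Unsupported cache type: q4_7
    }
    return 0;
}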
ggml/src/ggml.c: 6 changes (5 additions & 1 deletion)
@@ -324,8 +324,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};

 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -15723,6 +15724,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot = type_traits[k->type].vec_dot;
     ggml_to_float_t   const v_to_float = type_traits[v->type].to_float;

+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float  && "fattn: unsupported V-type");
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
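These asserts are the point of the commit: a type_traits entry for a type with no CPU conversion routine leaves to_float (or vec_dot) as a NULL function pointer, which the flash-attention kernel would otherwise call through deep inside its loop. A self-contained sketch of that failure mode, with an illustrative table standing in for ggml's type_traits and plain assert in place of GGML_ASSERT:

#include <cassert>
#include <cstddef>

typedef void (*to_float_t)(const void * src, float * dst, int n);

static void f16_to_float_stub(const void * /*src*/, float * dst, int n) {
    for (int i = 0; i < n; ++i) { dst[i] = 0.0f; } // placeholder conversion
}

struct traits_entry { const char * name; to_float_t to_float; };

// Illustrative table: one type has a conversion routine registered, one does not.
static const traits_entry traits[] = {
    { "f16",       f16_to_float_stub },
    { "no_cpu_fn", NULL },            // no to_float registered for this type
};

int main() {
    const to_float_t v_to_float = traits[1].to_float;

    // Without the assert, the code below would jump through a NULL function
    // pointer: undefined behavior, typically a segfault with no useful message.
    assert(v_to_float && "fattn: unsupported V-type"); // fails fast, readably

    float out[4];
    v_to_float(NULL, out, 4); // never reached once the assert fires
    return 0;
}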
src/llama.cpp: 2 changes (1 addition & 1 deletion)
@@ -19243,7 +19243,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
        LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
        return nullptr;
    }
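The new condition narrows the rejection: the old check refused any non-F16 V cache without flash_attn, while ggml_is_quantized only rejects quantized V-cache types, so a float V cache (e.g. F32) no longer requires flash attention. A quick probe of what the check classifies, assuming only that ggml_is_quantized and ggml_type_name are available from ggml.h as public API:

#include "ggml.h"
#include <cstdio>

int main() {
    const ggml_type types[] = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1 };
    for (ggml_type t : types) {
        // Old check: t != GGML_TYPE_F16   -> F32 was also rejected without flash_attn.
        // New check: ggml_is_quantized(t) -> only quantized types require flash_attn.
        std::printf("%-6s quantized=%d\n", ggml_type_name(t), ggml_is_quantized(t));
    }
    return 0;
}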